From 5e8cd693a5f32a6781cf0c1c5a7d4f53a0e8df74 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Wed, 13 Aug 2025 16:41:01 +0800 Subject: [PATCH] Refa: split services about llm. (#9450) ### What problem does this PR solve? ### Type of change - [x] Refactoring --- agent/component/agent_with_tools.py | 3 +- agent/component/llm.py | 3 +- api/apps/conversation_app.py | 4 +- api/apps/dialog_app.py | 2 +- api/apps/llm_app.py | 3 +- api/apps/sdk/chat.py | 2 +- api/apps/sdk/doc.py | 3 +- api/apps/sdk/session.py | 5 +- api/apps/user_app.py | 53 +---- api/db/init_data.py | 41 +--- api/db/services/dialog_service.py | 3 +- api/db/services/llm_service.py | 275 +++++--------------------- api/db/services/tenant_llm_service.py | 252 +++++++++++++++++++++++ api/utils/api_utils.py | 3 +- rag/prompts/prompts.py | 4 +- 15 files changed, 327 insertions(+), 329 deletions(-) create mode 100644 api/db/services/tenant_llm_service.py diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 40c47d9b5..f656af0e3 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -24,7 +24,8 @@ from typing import Any import json_repair from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta -from api.db.services.llm_service import LLMBundle, TenantLLMService +from api.db.services.llm_service import LLMBundle +from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.mcp_server_service import MCPServerService from api.utils.api_utils import timeout from rag.prompts import message_fit_in diff --git a/agent/component/llm.py b/agent/component/llm.py index 2d2e831bf..5e10220a4 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -24,7 +24,8 @@ from copy import deepcopy from functools import partial from api.db import LLMType -from api.db.services.llm_service import LLMBundle, TenantLLMService +from api.db.services.llm_service import LLMBundle +from api.db.services.tenant_llm_service import TenantLLMService from agent.component.base import ComponentBase, ComponentParamBase from api.utils.api_utils import timeout from rag.prompts import message_fit_in, citation_prompt diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 802ac50ca..b44101135 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -28,8 +28,8 @@ from api.db.db_models import APIToken from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.dialog_service import DialogService, ask, chat from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMBundle, TenantService -from api.db.services.user_service import UserTenantService +from api.db.services.llm_service import LLMBundle +from api.db.services.user_service import UserTenantService, TenantService from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from graphrag.general.mind_map_extractor import MindMapExtractor from rag.app.tag import label_question diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index c5c48b6e1..85778bb90 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -18,7 +18,7 @@ from flask import request from flask_login import login_required, current_user from api.db.services.dialog_service import DialogService from api.db import StatusEnum -from api.db.services.llm_service import TenantLLMService +from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService, UserTenantService from api import settings diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 2ec8180cd..a876cd712 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -17,7 +17,8 @@ import logging import json from flask import request from flask_login import login_required, current_user -from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService +from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService +from api.db.services.llm_service import LLMService from api import settings from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.db import StatusEnum, LLMType diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index edf5e2699..fc3559d58 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -21,7 +21,7 @@ from api import settings from api.db import StatusEnum from api.db.services.dialog_service import DialogService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import TenantLLMService +from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService from api.utils import get_uuid from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 24420f0f3..454633f07 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -32,7 +32,8 @@ from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMBundle, TenantLLMService +from api.db.services.llm_service import LLMBundle +from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.task_service import TaskService, queue_tasks from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required from rag.app.qa import beAdoc, rmPrefix diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index ce6b189bc..9e3454cd1 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -16,20 +16,17 @@ import json import re import time - import tiktoken from flask import Response, jsonify, request - from agent.canvas import Canvas from api.db import LLMType, StatusEnum -from api.db.db_models import API4Conversation, APIToken +from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService, completionOpenAI from api.db.services.canvas_service import completion as agent_completion from api.db.services.conversation_service import ConversationService, iframe_completion from api.db.services.conversation_service import completion as rag_completion from api.db.services.dialog_service import DialogService, ask, chat -from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.utils import get_uuid diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 93b07615c..76bd89563 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -28,7 +28,7 @@ from api.apps.auth import get_auth_client from api.db import FileType, UserTenantRole from api.db.db_models import TenantLLM from api.db.services.file_service import FileService -from api.db.services.llm_service import LLMService, TenantLLMService +from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm from api.db.services.user_service import TenantService, UserService, UserTenantService from api.utils import ( current_timestamp, @@ -619,57 +619,8 @@ def user_register(user_id, user): "size": 0, "location": "", } - tenant_llm = [] - seen = set() - factory_configs = [] - for factory_config in [ - settings.CHAT_CFG, - settings.EMBEDDING_CFG, - settings.ASR_CFG, - settings.IMAGE2TEXT_CFG, - settings.RERANK_CFG, - ]: - factory_name = factory_config["factory"] - if factory_name not in seen: - seen.add(factory_name) - factory_configs.append(factory_config) - - for factory_config in factory_configs: - for llm in LLMService.query(fid=factory_config["factory"]): - tenant_llm.append( - { - "tenant_id": user_id, - "llm_factory": factory_config["factory"], - "llm_name": llm.llm_name, - "model_type": llm.model_type, - "api_key": factory_config["api_key"], - "api_base": factory_config["base_url"], - "max_tokens": llm.max_tokens if llm.max_tokens else 8192, - } - ) - - if settings.LIGHTEN != 1: - for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: - mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) - tenant_llm.append( - { - "tenant_id": user_id, - "llm_factory": fid, - "llm_name": mdlnm, - "model_type": "embedding", - "api_key": "", - "api_base": "", - "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, - } - ) - - unique = {} - for item in tenant_llm: - key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) - if key not in unique: - unique[key] = item - tenant_llm = list(unique.values()) + tenant_llm = get_init_tenant_llm(user_id) if not UserService.save(**user): return diff --git a/api/db/init_data.py b/api/db/init_data.py index 83456035d..d462a7b2b 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -27,7 +27,8 @@ from api.db.services import UserService from api.db.services.canvas_service import CanvasTemplateService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle +from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService +from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm from api.db.services.user_service import TenantService, UserTenantService from api import settings from api.utils.file_utils import get_project_base_directory @@ -64,43 +65,7 @@ def init_superuser(): "role": UserTenantRole.OWNER } - user_id = user_info - tenant_llm = [] - - seen = set() - factory_configs = [] - for factory_config in [ - settings.CHAT_CFG["factory"], - settings.EMBEDDING_CFG["factory"], - settings.ASR_CFG["factory"], - settings.IMAGE2TEXT_CFG["factory"], - settings.RERANK_CFG["factory"], - ]: - factory_name = factory_config["factory"] - if factory_name not in seen: - seen.add(factory_name) - factory_configs.append(factory_config) - - for factory_config in factory_configs: - for llm in LLMService.query(fid=factory_config["factory"]): - tenant_llm.append( - { - "tenant_id": user_id, - "llm_factory": factory_config["factory"], - "llm_name": llm.llm_name, - "model_type": llm.model_type, - "api_key": factory_config["api_key"], - "api_base": factory_config["base_url"], - "max_tokens": llm.max_tokens if llm.max_tokens else 8192, - } - ) - - unique = {} - for item in tenant_llm: - key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) - if key not in unique: - unique[key] = item - tenant_llm = list(unique.values()) + tenant_llm = get_init_tenant_llm(user_info["id"]) if not UserService.save(**user_info): logging.error("can't init admin.") diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index c2748589b..aa4f8ba4e 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -33,7 +33,8 @@ from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService -from api.db.services.llm_service import LLMBundle, TenantLLMService +from api.db.services.llm_service import LLMBundle +from api.db.services.tenant_llm_service import TenantLLMService from api.utils import current_timestamp, datetime_format from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 046e0e274..60bdde7a3 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -18,246 +18,73 @@ import logging import re from functools import partial from typing import Generator - -from langfuse import Langfuse - -from api import settings -from api.db import LLMType -from api.db.db_models import DB, LLM, LLMFactories, TenantLLM +from api.db.db_models import LLM from api.db.services.common_service import CommonService -from api.db.services.langfuse_service import TenantLangfuseService -from api.db.services.user_service import TenantService -from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel - - -class LLMFactoriesService(CommonService): - model = LLMFactories +from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService class LLMService(CommonService): model = LLM -class TenantLLMService(CommonService): - model = TenantLLM +def get_init_tenant_llm(user_id): + from api import settings + tenant_llm = [] - @classmethod - @DB.connection_context() - def get_api_key(cls, tenant_id, model_name): - mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name) - if not fid: - objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm) - else: - objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) + seen = set() + factory_configs = [] + for factory_config in [ + settings.CHAT_CFG, + settings.EMBEDDING_CFG, + settings.ASR_CFG, + settings.IMAGE2TEXT_CFG, + settings.RERANK_CFG, + ]: + factory_name = factory_config["factory"] + if factory_name not in seen: + seen.add(factory_name) + factory_configs.append(factory_config) - if (not objs) and fid: - if fid == "LocalAI": - mdlnm += "___LocalAI" - elif fid == "HuggingFace": - mdlnm += "___HuggingFace" - elif fid == "OpenAI-API-Compatible": - mdlnm += "___OpenAI-API" - elif fid == "VLLM": - mdlnm += "___VLLM" - objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) - if not objs: - return - return objs[0] - - @classmethod - @DB.connection_context() - def get_my_llms(cls, tenant_id): - fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens] - objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() - - return list(objs) - - @staticmethod - def split_model_name_and_factory(model_name): - arr = model_name.split("@") - if len(arr) < 2: - return model_name, None - if len(arr) > 2: - return "@".join(arr[0:-1]), arr[-1] - - # model name must be xxx@yyy - try: - model_factories = settings.FACTORY_LLM_INFOS - model_providers = set([f["name"] for f in model_factories]) - if arr[-1] not in model_providers: - return model_name, None - return arr[0], arr[-1] - except Exception as e: - logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}") - return model_name, None - - @classmethod - @DB.connection_context() - def get_model_config(cls, tenant_id, llm_type, llm_name=None): - e, tenant = TenantService.get_by_id(tenant_id) - if not e: - raise LookupError("Tenant not found") - - if llm_type == LLMType.EMBEDDING.value: - mdlnm = tenant.embd_id if not llm_name else llm_name - elif llm_type == LLMType.SPEECH2TEXT.value: - mdlnm = tenant.asr_id - elif llm_type == LLMType.IMAGE2TEXT.value: - mdlnm = tenant.img2txt_id if not llm_name else llm_name - elif llm_type == LLMType.CHAT.value: - mdlnm = tenant.llm_id if not llm_name else llm_name - elif llm_type == LLMType.RERANK: - mdlnm = tenant.rerank_id if not llm_name else llm_name - elif llm_type == LLMType.TTS: - mdlnm = tenant.tts_id if not llm_name else llm_name - else: - assert False, "LLM type error" - - model_config = cls.get_api_key(tenant_id, mdlnm) - mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) - if not model_config: # for some cases seems fid mismatch - model_config = cls.get_api_key(tenant_id, mdlnm) - if model_config: - model_config = model_config.to_dict() - llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) - if not llm and fid: # for some cases seems fid mismatch - llm = LLMService.query(llm_name=mdlnm) - if llm: - model_config["is_tools"] = llm[0].is_tools - if not model_config: - if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: - llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) - if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: - model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} - if not model_config: - if mdlnm == "flag-embedding": - model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} - else: - if not mdlnm: - raise LookupError(f"Type of {llm_type} model is not set.") - raise LookupError("Model({}) not authorized".format(mdlnm)) - return model_config - - @classmethod - @DB.connection_context() - def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): - model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) - kwargs.update({"provider": model_config["llm_factory"]}) - if llm_type == LLMType.EMBEDDING.value: - if model_config["llm_factory"] not in EmbeddingModel: - return - return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) - - if llm_type == LLMType.RERANK: - if model_config["llm_factory"] not in RerankModel: - return - return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) - - if llm_type == LLMType.IMAGE2TEXT.value: - if model_config["llm_factory"] not in CvModel: - return - return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) - - if llm_type == LLMType.CHAT.value: - if model_config["llm_factory"] not in ChatModel: - return - return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) - - if llm_type == LLMType.SPEECH2TEXT: - if model_config["llm_factory"] not in Seq2txtModel: - return - return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) - if llm_type == LLMType.TTS: - if model_config["llm_factory"] not in TTSModel: - return - return TTSModel[model_config["llm_factory"]]( - model_config["api_key"], - model_config["llm_name"], - base_url=model_config["api_base"], + for factory_config in factory_configs: + for llm in LLMService.query(fid=factory_config["factory"]): + tenant_llm.append( + { + "tenant_id": user_id, + "llm_factory": factory_config["factory"], + "llm_name": llm.llm_name, + "model_type": llm.model_type, + "api_key": factory_config["api_key"], + "api_base": factory_config["base_url"], + "max_tokens": llm.max_tokens if llm.max_tokens else 8192, + } ) - @classmethod - @DB.connection_context() - def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): - e, tenant = TenantService.get_by_id(tenant_id) - if not e: - logging.error(f"Tenant not found: {tenant_id}") - return 0 - - llm_map = { - LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name, - LLMType.SPEECH2TEXT.value: tenant.asr_id, - LLMType.IMAGE2TEXT.value: tenant.img2txt_id, - LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, - LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, - LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, - } - - mdlnm = llm_map.get(llm_type) - if mdlnm is None: - logging.error(f"LLM type error: {llm_type}") - return 0 - - llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) - - try: - num = ( - cls.model.update(used_tokens=cls.model.used_tokens + used_tokens) - .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True) - .execute() + if settings.LIGHTEN != 1: + for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: + mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) + tenant_llm.append( + { + "tenant_id": user_id, + "llm_factory": fid, + "llm_name": mdlnm, + "model_type": "embedding", + "api_key": "", + "api_base": "", + "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, + } ) - except Exception: - logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name) - return 0 - return num - - @classmethod - @DB.connection_context() - def get_openai_models(cls): - objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() - return list(objs) - - @staticmethod - def llm_id2llm_type(llm_id: str) -> str | None: - llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id) - llm_factories = settings.FACTORY_LLM_INFOS - for llm_factory in llm_factories: - for llm in llm_factory["llm"]: - if llm_id == llm["llm_name"]: - return llm["model_type"].split(",")[-1] - - for llm in LLMService.query(llm_name=llm_id): - return llm.model_type - - llm = TenantLLMService.get_or_none(llm_name=llm_id) - if llm: - return llm.model_type - for llm in TenantLLMService.query(llm_name=llm_id): - return llm.model_type + unique = {} + for item in tenant_llm: + key = (item["tenant_id"], item["llm_factory"], item["llm_name"]) + if key not in unique: + unique[key] = item + return list(unique.values()) -class LLMBundle: +class LLMBundle(LLM4Tenant): def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): - self.tenant_id = tenant_id - self.llm_type = llm_type - self.llm_name = llm_name - self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs) - assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name) - model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) - self.max_length = model_config.get("max_tokens", 8192) - - self.is_tools = model_config.get("is_tools", False) - self.verbose_tool_use = kwargs.get("verbose_tool_use") - - langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id) - self.langfuse = None - if langfuse_keys: - langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) - if langfuse.auth_check(): - self.langfuse = langfuse - trace_id = self.langfuse.create_trace_id() - self.trace_context = {"trace_id": trace_id} + super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs) def bind_tools(self, toolcall_session, tools): if not self.is_tools: diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py new file mode 100644 index 000000000..ec023f115 --- /dev/null +++ b/api/db/services/tenant_llm_service.py @@ -0,0 +1,252 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from langfuse import Langfuse +from api import settings +from api.db import LLMType +from api.db.db_models import DB, LLMFactories, TenantLLM +from api.db.services.common_service import CommonService +from api.db.services.langfuse_service import TenantLangfuseService +from api.db.services.user_service import TenantService +from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel + + +class LLMFactoriesService(CommonService): + model = LLMFactories + + +class TenantLLMService(CommonService): + model = TenantLLM + + @classmethod + @DB.connection_context() + def get_api_key(cls, tenant_id, model_name): + mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name) + if not fid: + objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm) + else: + objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) + + if (not objs) and fid: + if fid == "LocalAI": + mdlnm += "___LocalAI" + elif fid == "HuggingFace": + mdlnm += "___HuggingFace" + elif fid == "OpenAI-API-Compatible": + mdlnm += "___OpenAI-API" + elif fid == "VLLM": + mdlnm += "___VLLM" + objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) + if not objs: + return + return objs[0] + + @classmethod + @DB.connection_context() + def get_my_llms(cls, tenant_id): + fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens] + objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() + + return list(objs) + + @staticmethod + def split_model_name_and_factory(model_name): + arr = model_name.split("@") + if len(arr) < 2: + return model_name, None + if len(arr) > 2: + return "@".join(arr[0:-1]), arr[-1] + + # model name must be xxx@yyy + try: + model_factories = settings.FACTORY_LLM_INFOS + model_providers = set([f["name"] for f in model_factories]) + if arr[-1] not in model_providers: + return model_name, None + return arr[0], arr[-1] + except Exception as e: + logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}") + return model_name, None + + @classmethod + @DB.connection_context() + def get_model_config(cls, tenant_id, llm_type, llm_name=None): + from api.db.services.llm_service import LLMService + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + raise LookupError("Tenant not found") + + if llm_type == LLMType.EMBEDDING.value: + mdlnm = tenant.embd_id if not llm_name else llm_name + elif llm_type == LLMType.SPEECH2TEXT.value: + mdlnm = tenant.asr_id + elif llm_type == LLMType.IMAGE2TEXT.value: + mdlnm = tenant.img2txt_id if not llm_name else llm_name + elif llm_type == LLMType.CHAT.value: + mdlnm = tenant.llm_id if not llm_name else llm_name + elif llm_type == LLMType.RERANK: + mdlnm = tenant.rerank_id if not llm_name else llm_name + elif llm_type == LLMType.TTS: + mdlnm = tenant.tts_id if not llm_name else llm_name + else: + assert False, "LLM type error" + + model_config = cls.get_api_key(tenant_id, mdlnm) + mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm) + if not model_config: # for some cases seems fid mismatch + model_config = cls.get_api_key(tenant_id, mdlnm) + if model_config: + model_config = model_config.to_dict() + llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) + if not llm and fid: # for some cases seems fid mismatch + llm = LLMService.query(llm_name=mdlnm) + if llm: + model_config["is_tools"] = llm[0].is_tools + if not model_config: + if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: + llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) + if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: + model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} + if not model_config: + if mdlnm == "flag-embedding": + model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} + else: + if not mdlnm: + raise LookupError(f"Type of {llm_type} model is not set.") + raise LookupError("Model({}) not authorized".format(mdlnm)) + return model_config + + @classmethod + @DB.connection_context() + def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): + model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) + kwargs.update({"provider": model_config["llm_factory"]}) + if llm_type == LLMType.EMBEDDING.value: + if model_config["llm_factory"] not in EmbeddingModel: + return + return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + + if llm_type == LLMType.RERANK: + if model_config["llm_factory"] not in RerankModel: + return + return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + + if llm_type == LLMType.IMAGE2TEXT.value: + if model_config["llm_factory"] not in CvModel: + return + return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) + + if llm_type == LLMType.CHAT.value: + if model_config["llm_factory"] not in ChatModel: + return + return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) + + if llm_type == LLMType.SPEECH2TEXT: + if model_config["llm_factory"] not in Seq2txtModel: + return + return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) + if llm_type == LLMType.TTS: + if model_config["llm_factory"] not in TTSModel: + return + return TTSModel[model_config["llm_factory"]]( + model_config["api_key"], + model_config["llm_name"], + base_url=model_config["api_base"], + ) + + @classmethod + @DB.connection_context() + def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): + e, tenant = TenantService.get_by_id(tenant_id) + if not e: + logging.error(f"Tenant not found: {tenant_id}") + return 0 + + llm_map = { + LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name, + LLMType.SPEECH2TEXT.value: tenant.asr_id, + LLMType.IMAGE2TEXT.value: tenant.img2txt_id, + LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, + LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, + LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, + } + + mdlnm = llm_map.get(llm_type) + if mdlnm is None: + logging.error(f"LLM type error: {llm_type}") + return 0 + + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm) + + try: + num = ( + cls.model.update(used_tokens=cls.model.used_tokens + used_tokens) + .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True) + .execute() + ) + except Exception: + logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name) + return 0 + + return num + + @classmethod + @DB.connection_context() + def get_openai_models(cls): + objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() + return list(objs) + + @staticmethod + def llm_id2llm_type(llm_id: str) -> str | None: + from api.db.services.llm_service import LLMService + llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id) + llm_factories = settings.FACTORY_LLM_INFOS + for llm_factory in llm_factories: + for llm in llm_factory["llm"]: + if llm_id == llm["llm_name"]: + return llm["model_type"].split(",")[-1] + + for llm in LLMService.query(llm_name=llm_id): + return llm.model_type + + llm = TenantLLMService.get_or_none(llm_name=llm_id) + if llm: + return llm.model_type + for llm in TenantLLMService.query(llm_name=llm_id): + return llm.model_type + + +class LLM4Tenant: + def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs): + self.tenant_id = tenant_id + self.llm_type = llm_type + self.llm_name = llm_name + self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs) + assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name) + model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name) + self.max_length = model_config.get("max_tokens", 8192) + + self.is_tools = model_config.get("is_tools", False) + self.verbose_tool_use = kwargs.get("verbose_tool_use") + + langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id) + self.langfuse = None + if langfuse_keys: + langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) + if langfuse.auth_check(): + self.langfuse = langfuse + trace_id = self.langfuse.create_trace_id() + self.trace_context = {"trace_id": trace_id} \ No newline at end of file diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 8e98c1d8d..4de7306b1 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -48,7 +48,8 @@ from werkzeug.http import HTTP_STATUS_CODES from api import settings from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC from api.db.db_models import APIToken -from api.db.services.llm_service import LLMService, TenantLLMService +from api.db.services.llm_service import LLMService +from api.db.services.tenant_llm_service import TenantLLMService from api.utils import CustomJSONEncoder, get_uuid, json_dumps from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions diff --git a/rag/prompts/prompts.py b/rag/prompts/prompts.py index e3ce3e457..c49c92ea9 100644 --- a/rag/prompts/prompts.py +++ b/rag/prompts/prompts.py @@ -197,7 +197,7 @@ def question_proposal(chat_mdl, content, topn=3): def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): from api.db import LLMType from api.db.services.llm_service import LLMBundle - from api.db.services.llm_service import TenantLLMService + from api.db.services.tenant_llm_service import TenantLLMService if not chat_mdl: if TenantLLMService.llm_id2llm_type(llm_id) == "image2text": @@ -231,7 +231,7 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_ def cross_languages(tenant_id, llm_id, query, languages=[]): from api.db import LLMType from api.db.services.llm_service import LLMBundle - from api.db.services.llm_service import TenantLLMService + from api.db.services.tenant_llm_service import TenantLLMService if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text": chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)