mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: split services about llm. (#9450)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
@ -24,7 +24,8 @@ from typing import Any
|
|||||||
import json_repair
|
import json_repair
|
||||||
|
|
||||||
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
|
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.mcp_server_service import MCPServerService
|
from api.db.services.mcp_server_service import MCPServerService
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.prompts import message_fit_in
|
from rag.prompts import message_fit_in
|
||||||
|
|||||||
@ -24,7 +24,8 @@ from copy import deepcopy
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.prompts import message_fit_in, citation_prompt
|
from rag.prompts import message_fit_in, citation_prompt
|
||||||
|
|||||||
@ -28,8 +28,8 @@ from api.db.db_models import APIToken
|
|||||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat
|
from api.db.services.dialog_service import DialogService, ask, chat
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantService
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService, TenantService
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from flask import request
|
|||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api import settings
|
from api import settings
|
||||||
|
|||||||
@ -17,7 +17,8 @@ import logging
|
|||||||
import json
|
import json
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||||
|
from api.db.services.llm_service import LLMService
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.db import StatusEnum, LLMType
|
from api.db import StatusEnum, LLMType
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from api import settings
|
|||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||||
|
|||||||
@ -32,7 +32,8 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.task_service import TaskService, queue_tasks
|
from api.db.services.task_service import TaskService, queue_tasks
|
||||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
|
|||||||
@ -16,20 +16,17 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from flask import Response, jsonify, request
|
from flask import Response, jsonify, request
|
||||||
|
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
from api.db import LLMType, StatusEnum
|
from api.db import LLMType, StatusEnum
|
||||||
from api.db.db_models import API4Conversation, APIToken
|
from api.db.db_models import APIToken
|
||||||
from api.db.services.api_service import API4ConversationService
|
from api.db.services.api_service import API4ConversationService
|
||||||
from api.db.services.canvas_service import UserCanvasService, completionOpenAI
|
from api.db.services.canvas_service import UserCanvasService, completionOpenAI
|
||||||
from api.db.services.canvas_service import completion as agent_completion
|
from api.db.services.canvas_service import completion as agent_completion
|
||||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||||
from api.db.services.conversation_service import completion as rag_completion
|
from api.db.services.conversation_service import completion as rag_completion
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat
|
from api.db.services.dialog_service import DialogService, ask, chat
|
||||||
from api.db.services.file_service import FileService
|
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from api.apps.auth import get_auth_client
|
|||||||
from api.db import FileType, UserTenantRole
|
from api.db import FileType, UserTenantRole
|
||||||
from api.db.db_models import TenantLLM
|
from api.db.db_models import TenantLLM
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm
|
||||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||||
from api.utils import (
|
from api.utils import (
|
||||||
current_timestamp,
|
current_timestamp,
|
||||||
@ -619,57 +619,8 @@ def user_register(user_id, user):
|
|||||||
"size": 0,
|
"size": 0,
|
||||||
"location": "",
|
"location": "",
|
||||||
}
|
}
|
||||||
tenant_llm = []
|
|
||||||
|
|
||||||
seen = set()
|
tenant_llm = get_init_tenant_llm(user_id)
|
||||||
factory_configs = []
|
|
||||||
for factory_config in [
|
|
||||||
settings.CHAT_CFG,
|
|
||||||
settings.EMBEDDING_CFG,
|
|
||||||
settings.ASR_CFG,
|
|
||||||
settings.IMAGE2TEXT_CFG,
|
|
||||||
settings.RERANK_CFG,
|
|
||||||
]:
|
|
||||||
factory_name = factory_config["factory"]
|
|
||||||
if factory_name not in seen:
|
|
||||||
seen.add(factory_name)
|
|
||||||
factory_configs.append(factory_config)
|
|
||||||
|
|
||||||
for factory_config in factory_configs:
|
|
||||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
|
||||||
tenant_llm.append(
|
|
||||||
{
|
|
||||||
"tenant_id": user_id,
|
|
||||||
"llm_factory": factory_config["factory"],
|
|
||||||
"llm_name": llm.llm_name,
|
|
||||||
"model_type": llm.model_type,
|
|
||||||
"api_key": factory_config["api_key"],
|
|
||||||
"api_base": factory_config["base_url"],
|
|
||||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if settings.LIGHTEN != 1:
|
|
||||||
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
|
||||||
tenant_llm.append(
|
|
||||||
{
|
|
||||||
"tenant_id": user_id,
|
|
||||||
"llm_factory": fid,
|
|
||||||
"llm_name": mdlnm,
|
|
||||||
"model_type": "embedding",
|
|
||||||
"api_key": "",
|
|
||||||
"api_base": "",
|
|
||||||
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
unique = {}
|
|
||||||
for item in tenant_llm:
|
|
||||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
|
||||||
if key not in unique:
|
|
||||||
unique[key] = item
|
|
||||||
tenant_llm = list(unique.values())
|
|
||||||
|
|
||||||
if not UserService.save(**user):
|
if not UserService.save(**user):
|
||||||
return
|
return
|
||||||
|
|||||||
@ -27,7 +27,8 @@ from api.db.services import UserService
|
|||||||
from api.db.services.canvas_service import CanvasTemplateService
|
from api.db.services.canvas_service import CanvasTemplateService
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||||
|
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
@ -64,43 +65,7 @@ def init_superuser():
|
|||||||
"role": UserTenantRole.OWNER
|
"role": UserTenantRole.OWNER
|
||||||
}
|
}
|
||||||
|
|
||||||
user_id = user_info
|
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||||
tenant_llm = []
|
|
||||||
|
|
||||||
seen = set()
|
|
||||||
factory_configs = []
|
|
||||||
for factory_config in [
|
|
||||||
settings.CHAT_CFG["factory"],
|
|
||||||
settings.EMBEDDING_CFG["factory"],
|
|
||||||
settings.ASR_CFG["factory"],
|
|
||||||
settings.IMAGE2TEXT_CFG["factory"],
|
|
||||||
settings.RERANK_CFG["factory"],
|
|
||||||
]:
|
|
||||||
factory_name = factory_config["factory"]
|
|
||||||
if factory_name not in seen:
|
|
||||||
seen.add(factory_name)
|
|
||||||
factory_configs.append(factory_config)
|
|
||||||
|
|
||||||
for factory_config in factory_configs:
|
|
||||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
|
||||||
tenant_llm.append(
|
|
||||||
{
|
|
||||||
"tenant_id": user_id,
|
|
||||||
"llm_factory": factory_config["factory"],
|
|
||||||
"llm_name": llm.llm_name,
|
|
||||||
"model_type": llm.model_type,
|
|
||||||
"api_key": factory_config["api_key"],
|
|
||||||
"api_base": factory_config["base_url"],
|
|
||||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
unique = {}
|
|
||||||
for item in tenant_llm:
|
|
||||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
|
||||||
if key not in unique:
|
|
||||||
unique[key] = item
|
|
||||||
tenant_llm = list(unique.values())
|
|
||||||
|
|
||||||
if not UserService.save(**user_info):
|
if not UserService.save(**user_info):
|
||||||
logging.error("can't init admin.")
|
logging.error("can't init admin.")
|
||||||
|
|||||||
@ -33,7 +33,8 @@ from api.db.services.common_service import CommonService
|
|||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.langfuse_service import TenantLangfuseService
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.utils import current_timestamp, datetime_format
|
from api.utils import current_timestamp, datetime_format
|
||||||
from rag.app.resume import forbidden_select_fields4resume
|
from rag.app.resume import forbidden_select_fields4resume
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
|
|||||||
@ -18,246 +18,73 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
from api.db.db_models import LLM
|
||||||
from langfuse import Langfuse
|
|
||||||
|
|
||||||
from api import settings
|
|
||||||
from api.db import LLMType
|
|
||||||
from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
|
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.langfuse_service import TenantLangfuseService
|
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||||
from api.db.services.user_service import TenantService
|
|
||||||
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
|
||||||
|
|
||||||
|
|
||||||
class LLMFactoriesService(CommonService):
|
|
||||||
model = LLMFactories
|
|
||||||
|
|
||||||
|
|
||||||
class LLMService(CommonService):
|
class LLMService(CommonService):
|
||||||
model = LLM
|
model = LLM
|
||||||
|
|
||||||
|
|
||||||
class TenantLLMService(CommonService):
|
def get_init_tenant_llm(user_id):
|
||||||
model = TenantLLM
|
from api import settings
|
||||||
|
tenant_llm = []
|
||||||
|
|
||||||
@classmethod
|
seen = set()
|
||||||
@DB.connection_context()
|
factory_configs = []
|
||||||
def get_api_key(cls, tenant_id, model_name):
|
for factory_config in [
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
settings.CHAT_CFG,
|
||||||
if not fid:
|
settings.EMBEDDING_CFG,
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
settings.ASR_CFG,
|
||||||
else:
|
settings.IMAGE2TEXT_CFG,
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
settings.RERANK_CFG,
|
||||||
|
]:
|
||||||
|
factory_name = factory_config["factory"]
|
||||||
|
if factory_name not in seen:
|
||||||
|
seen.add(factory_name)
|
||||||
|
factory_configs.append(factory_config)
|
||||||
|
|
||||||
if (not objs) and fid:
|
for factory_config in factory_configs:
|
||||||
if fid == "LocalAI":
|
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||||
mdlnm += "___LocalAI"
|
tenant_llm.append(
|
||||||
elif fid == "HuggingFace":
|
{
|
||||||
mdlnm += "___HuggingFace"
|
"tenant_id": user_id,
|
||||||
elif fid == "OpenAI-API-Compatible":
|
"llm_factory": factory_config["factory"],
|
||||||
mdlnm += "___OpenAI-API"
|
"llm_name": llm.llm_name,
|
||||||
elif fid == "VLLM":
|
"model_type": llm.model_type,
|
||||||
mdlnm += "___VLLM"
|
"api_key": factory_config["api_key"],
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
"api_base": factory_config["base_url"],
|
||||||
if not objs:
|
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||||
return
|
}
|
||||||
return objs[0]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@DB.connection_context()
|
|
||||||
def get_my_llms(cls, tenant_id):
|
|
||||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
|
||||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
|
||||||
|
|
||||||
return list(objs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def split_model_name_and_factory(model_name):
|
|
||||||
arr = model_name.split("@")
|
|
||||||
if len(arr) < 2:
|
|
||||||
return model_name, None
|
|
||||||
if len(arr) > 2:
|
|
||||||
return "@".join(arr[0:-1]), arr[-1]
|
|
||||||
|
|
||||||
# model name must be xxx@yyy
|
|
||||||
try:
|
|
||||||
model_factories = settings.FACTORY_LLM_INFOS
|
|
||||||
model_providers = set([f["name"] for f in model_factories])
|
|
||||||
if arr[-1] not in model_providers:
|
|
||||||
return model_name, None
|
|
||||||
return arr[0], arr[-1]
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
|
||||||
return model_name, None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@DB.connection_context()
|
|
||||||
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
|
||||||
e, tenant = TenantService.get_by_id(tenant_id)
|
|
||||||
if not e:
|
|
||||||
raise LookupError("Tenant not found")
|
|
||||||
|
|
||||||
if llm_type == LLMType.EMBEDDING.value:
|
|
||||||
mdlnm = tenant.embd_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
|
||||||
mdlnm = tenant.asr_id
|
|
||||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
|
||||||
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.CHAT.value:
|
|
||||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.RERANK:
|
|
||||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.TTS:
|
|
||||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
|
||||||
else:
|
|
||||||
assert False, "LLM type error"
|
|
||||||
|
|
||||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
|
||||||
if not model_config: # for some cases seems fid mismatch
|
|
||||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
|
||||||
if model_config:
|
|
||||||
model_config = model_config.to_dict()
|
|
||||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
|
||||||
if not llm and fid: # for some cases seems fid mismatch
|
|
||||||
llm = LLMService.query(llm_name=mdlnm)
|
|
||||||
if llm:
|
|
||||||
model_config["is_tools"] = llm[0].is_tools
|
|
||||||
if not model_config:
|
|
||||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
|
||||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
|
||||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
|
||||||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
|
||||||
if not model_config:
|
|
||||||
if mdlnm == "flag-embedding":
|
|
||||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
|
||||||
else:
|
|
||||||
if not mdlnm:
|
|
||||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
|
||||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
|
||||||
return model_config
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@DB.connection_context()
|
|
||||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
|
||||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
|
||||||
kwargs.update({"provider": model_config["llm_factory"]})
|
|
||||||
if llm_type == LLMType.EMBEDDING.value:
|
|
||||||
if model_config["llm_factory"] not in EmbeddingModel:
|
|
||||||
return
|
|
||||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
|
||||||
|
|
||||||
if llm_type == LLMType.RERANK:
|
|
||||||
if model_config["llm_factory"] not in RerankModel:
|
|
||||||
return
|
|
||||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
|
||||||
|
|
||||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
|
||||||
if model_config["llm_factory"] not in CvModel:
|
|
||||||
return
|
|
||||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
|
||||||
|
|
||||||
if llm_type == LLMType.CHAT.value:
|
|
||||||
if model_config["llm_factory"] not in ChatModel:
|
|
||||||
return
|
|
||||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
|
||||||
|
|
||||||
if llm_type == LLMType.SPEECH2TEXT:
|
|
||||||
if model_config["llm_factory"] not in Seq2txtModel:
|
|
||||||
return
|
|
||||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
|
||||||
if llm_type == LLMType.TTS:
|
|
||||||
if model_config["llm_factory"] not in TTSModel:
|
|
||||||
return
|
|
||||||
return TTSModel[model_config["llm_factory"]](
|
|
||||||
model_config["api_key"],
|
|
||||||
model_config["llm_name"],
|
|
||||||
base_url=model_config["api_base"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
if settings.LIGHTEN != 1:
|
||||||
@DB.connection_context()
|
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
||||||
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
||||||
e, tenant = TenantService.get_by_id(tenant_id)
|
tenant_llm.append(
|
||||||
if not e:
|
{
|
||||||
logging.error(f"Tenant not found: {tenant_id}")
|
"tenant_id": user_id,
|
||||||
return 0
|
"llm_factory": fid,
|
||||||
|
"llm_name": mdlnm,
|
||||||
llm_map = {
|
"model_type": "embedding",
|
||||||
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
|
"api_key": "",
|
||||||
LLMType.SPEECH2TEXT.value: tenant.asr_id,
|
"api_base": "",
|
||||||
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
||||||
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
}
|
||||||
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
|
||||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
mdlnm = llm_map.get(llm_type)
|
|
||||||
if mdlnm is None:
|
|
||||||
logging.error(f"LLM type error: {llm_type}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
|
||||||
|
|
||||||
try:
|
|
||||||
num = (
|
|
||||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
|
||||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
|
||||||
.execute()
|
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return num
|
unique = {}
|
||||||
|
for item in tenant_llm:
|
||||||
@classmethod
|
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||||
@DB.connection_context()
|
if key not in unique:
|
||||||
def get_openai_models(cls):
|
unique[key] = item
|
||||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
return list(unique.values())
|
||||||
return list(objs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
|
||||||
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
|
||||||
llm_factories = settings.FACTORY_LLM_INFOS
|
|
||||||
for llm_factory in llm_factories:
|
|
||||||
for llm in llm_factory["llm"]:
|
|
||||||
if llm_id == llm["llm_name"]:
|
|
||||||
return llm["model_type"].split(",")[-1]
|
|
||||||
|
|
||||||
for llm in LLMService.query(llm_name=llm_id):
|
|
||||||
return llm.model_type
|
|
||||||
|
|
||||||
llm = TenantLLMService.get_or_none(llm_name=llm_id)
|
|
||||||
if llm:
|
|
||||||
return llm.model_type
|
|
||||||
for llm in TenantLLMService.query(llm_name=llm_id):
|
|
||||||
return llm.model_type
|
|
||||||
|
|
||||||
|
|
||||||
class LLMBundle:
|
class LLMBundle(LLM4Tenant):
|
||||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||||
self.tenant_id = tenant_id
|
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
|
||||||
self.llm_type = llm_type
|
|
||||||
self.llm_name = llm_name
|
|
||||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
|
||||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
|
||||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
|
||||||
self.max_length = model_config.get("max_tokens", 8192)
|
|
||||||
|
|
||||||
self.is_tools = model_config.get("is_tools", False)
|
|
||||||
self.verbose_tool_use = kwargs.get("verbose_tool_use")
|
|
||||||
|
|
||||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
|
||||||
self.langfuse = None
|
|
||||||
if langfuse_keys:
|
|
||||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
|
||||||
if langfuse.auth_check():
|
|
||||||
self.langfuse = langfuse
|
|
||||||
trace_id = self.langfuse.create_trace_id()
|
|
||||||
self.trace_context = {"trace_id": trace_id}
|
|
||||||
|
|
||||||
def bind_tools(self, toolcall_session, tools):
|
def bind_tools(self, toolcall_session, tools):
|
||||||
if not self.is_tools:
|
if not self.is_tools:
|
||||||
|
|||||||
252
api/db/services/tenant_llm_service.py
Normal file
252
api/db/services/tenant_llm_service.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
from langfuse import Langfuse
|
||||||
|
from api import settings
|
||||||
|
from api.db import LLMType
|
||||||
|
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||||
|
from api.db.services.common_service import CommonService
|
||||||
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
|
from api.db.services.user_service import TenantService
|
||||||
|
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFactoriesService(CommonService):
|
||||||
|
model = LLMFactories
|
||||||
|
|
||||||
|
|
||||||
|
class TenantLLMService(CommonService):
    """CRUD and resolution helpers for per-tenant LLM configurations.

    Backed by the ``tenant_llm`` table. Resolves model identifiers of the
    form ``<name>`` or ``<name>@<factory>`` to stored credentials and
    instantiates the matching wrapper class from ``rag.llm``.
    """

    # Peewee model this CommonService operates on.
    model = TenantLLM

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_name):
        """Return the ``TenantLLM`` row holding credentials for *model_name*.

        *model_name* may be plain (``gpt-4o``) or factory-qualified
        (``gpt-4o@OpenAI``). For a few factories the stored ``llm_name``
        carries a ``___<suffix>`` marker, so a second lookup with that suffix
        is attempted when the first one misses.

        Returns:
            The first matching ``TenantLLM`` row, or ``None`` if no row matches.
        """
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
        if not fid:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
        else:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)

        # Retry with the factory-specific "___" suffix some providers store
        # as part of the llm_name (e.g. "model___LocalAI").
        if (not objs) and fid:
            if fid == "LocalAI":
                mdlnm += "___LocalAI"
            elif fid == "HuggingFace":
                mdlnm += "___HuggingFace"
            elif fid == "OpenAI-API-Compatible":
                mdlnm += "___OpenAI-API"
            elif fid == "VLLM":
                mdlnm += "___VLLM"
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
        if not objs:
            return
        return objs[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        """List the tenant's configured models (only rows with an API key set),
        joined with factory metadata (logo, tags), as a list of dicts."""
        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()

        return list(objs)

    @staticmethod
    def split_model_name_and_factory(model_name):
        """Split ``"<name>@<factory>"`` into ``(name, factory)``.

        Returns ``(model_name, None)`` when there is no ``@`` separator, or
        when the trailing part is not a known factory from
        ``settings.FACTORY_LLM_INFOS`` (so model names containing a literal
        ``@`` are not mis-split). With multiple ``@`` only the last segment
        is treated as the factory candidate.
        """
        arr = model_name.split("@")
        if len(arr) < 2:
            return model_name, None
        if len(arr) > 2:
            return "@".join(arr[0:-1]), arr[-1]

        # model name must be xxx@yyy
        try:
            model_factories = settings.FACTORY_LLM_INFOS
            model_providers = set([f["name"] for f in model_factories])
            if arr[-1] not in model_providers:
                return model_name, None
            return arr[0], arr[-1]
        except Exception as e:
            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
        # On any lookup failure, fall back to treating the whole string as a name.
        return model_name, None

    @classmethod
    @DB.connection_context()
    def get_model_config(cls, tenant_id, llm_type, llm_name=None):
        """Resolve the effective model configuration dict for a tenant.

        Picks *llm_name* when given, otherwise the tenant's default model for
        *llm_type*; then looks up stored credentials, with fallbacks for
        factory-name mismatches and for key-less local embedding/rerank
        factories ("Youdao", "FastEmbed", "BAAI") and "flag-embedding".

        Raises:
            LookupError: when the tenant does not exist, no model is set for
                *llm_type*, or the model is not authorized for this tenant.
        """
        # Local import to avoid a circular dependency with llm_service.
        from api.db.services.llm_service import LLMService
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        # NOTE: comparisons mix LLMType members and .value; this relies on
        # LLMType being a str-based enum so both compare equal to the string.
        if llm_type == LLMType.EMBEDDING.value:
            mdlnm = tenant.embd_id if not llm_name else llm_name
        elif llm_type == LLMType.SPEECH2TEXT.value:
            mdlnm = tenant.asr_id
        elif llm_type == LLMType.IMAGE2TEXT.value:
            mdlnm = tenant.img2txt_id if not llm_name else llm_name
        elif llm_type == LLMType.CHAT.value:
            mdlnm = tenant.llm_id if not llm_name else llm_name
        elif llm_type == LLMType.RERANK:
            mdlnm = tenant.rerank_id if not llm_name else llm_name
        elif llm_type == LLMType.TTS:
            mdlnm = tenant.tts_id if not llm_name else llm_name
        else:
            assert False, "LLM type error"

        model_config = cls.get_api_key(tenant_id, mdlnm)
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
        if not model_config:  # for some cases seems fid mismatch
            model_config = cls.get_api_key(tenant_id, mdlnm)
        if model_config:
            model_config = model_config.to_dict()
            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
            if not llm and fid:  # for some cases seems fid mismatch
                llm = LLMService.query(llm_name=mdlnm)
            if llm:
                model_config["is_tools"] = llm[0].is_tools
        if not model_config:
            # Key-less fallbacks: some embedding/rerank factories need no API key.
            if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
            if not model_config:
                if mdlnm == "flag-embedding":
                    model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
                else:
                    if not mdlnm:
                        raise LookupError(f"Type of {llm_type} model is not set.")
                    raise LookupError("Model({}) not authorized".format(mdlnm))
        return model_config

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        """Instantiate the rag.llm wrapper for the tenant's configured model.

        Dispatches on *llm_type* to the matching model registry
        (EmbeddingModel, RerankModel, CvModel, ChatModel, Seq2txtModel,
        TTSModel). Returns ``None`` when the resolved factory is not present
        in the registry for that type.
        """
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        kwargs.update({"provider": model_config["llm_factory"]})
        if llm_type == LLMType.EMBEDDING.value:
            if model_config["llm_factory"] not in EmbeddingModel:
                return
            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.RERANK:
            if model_config["llm_factory"] not in RerankModel:
                return
            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.IMAGE2TEXT.value:
            if model_config["llm_factory"] not in CvModel:
                return
            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.CHAT.value:
            if model_config["llm_factory"] not in ChatModel:
                return
            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.SPEECH2TEXT:
            if model_config["llm_factory"] not in Seq2txtModel:
                return
            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
        if llm_type == LLMType.TTS:
            if model_config["llm_factory"] not in TTSModel:
                return
            return TTSModel[model_config["llm_factory"]](
                model_config["api_key"],
                model_config["llm_name"],
                base_url=model_config["api_base"],
            )

    @classmethod
    @DB.connection_context()
    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
        """Add *used_tokens* to the tenant's counter for the model of *llm_type*.

        Returns:
            The number of updated rows, or 0 on any failure (unknown tenant,
            unknown llm_type, or database error) — errors are logged, not raised.
        """
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            logging.error(f"Tenant not found: {tenant_id}")
            return 0

        # Map llm_type to the explicit llm_name or the tenant's default model.
        llm_map = {
            LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
            LLMType.SPEECH2TEXT.value: tenant.asr_id,
            LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
            LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
            LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
        }

        mdlnm = llm_map.get(llm_type)
        if mdlnm is None:
            logging.error(f"LLM type error: {llm_type}")
            return 0

        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)

        try:
            # When no factory was parsed, match on name only ("... if llm_factory else True").
            num = (
                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
                .execute()
            )
        except Exception:
            logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
            return 0

        return num

    @classmethod
    @DB.connection_context()
    def get_openai_models(cls):
        """List all OpenAI-factory rows except the two embedding models, as dicts."""
        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
        return list(objs)

    @staticmethod
    def llm_id2llm_type(llm_id: str) -> str | None:
        """Best-effort mapping from a model id to its model type (e.g. "chat").

        Lookup order: built-in factory catalog (last entry of the
        comma-separated ``model_type``), then the global LLM table, then the
        tenant LLM table. Returns ``None`` when the id is unknown everywhere.
        """
        # Local import to avoid a circular dependency with llm_service.
        from api.db.services.llm_service import LLMService
        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
        llm_factories = settings.FACTORY_LLM_INFOS
        for llm_factory in llm_factories:
            for llm in llm_factory["llm"]:
                if llm_id == llm["llm_name"]:
                    return llm["model_type"].split(",")[-1]

        for llm in LLMService.query(llm_name=llm_id):
            return llm.model_type

        llm = TenantLLMService.get_or_none(llm_name=llm_id)
        if llm:
            return llm.model_type
        for llm in TenantLLMService.query(llm_name=llm_id):
            return llm.model_type
|
||||||
|
|
||||||
|
|
||||||
|
class LLM4Tenant:
    """Binds a tenant to one concrete model instance.

    Resolves the tenant's configured model via ``TenantLLMService`` and,
    when the tenant has Langfuse credentials that pass an auth check,
    attaches a Langfuse client plus a trace context for observability.
    """

    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        """Instantiate the tenant's model of *llm_type*.

        Args:
            tenant_id: Owner tenant of the model configuration.
            llm_type: An ``LLMType`` value (chat, embedding, rerank, ...).
            llm_name: Optional explicit model name; when omitted the tenant's
                default model for *llm_type* is used.
            lang: Language hint forwarded to the model factory.
            **kwargs: Extra options forwarded to the model factory;
                ``verbose_tool_use`` is also kept on the instance.

        Raises:
            AssertionError: if no model wrapper could be instantiated.
            LookupError: propagated from config resolution when the tenant or
                model cannot be resolved.
        """
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        self.max_length = model_config.get("max_tokens", 8192)

        self.is_tools = model_config.get("is_tools", False)
        self.verbose_tool_use = kwargs.get("verbose_tool_use")

        langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
        self.langfuse = None
        # BUGFIX: previously ``trace_context`` was assigned only when the
        # Langfuse auth check succeeded, so readers of the attribute could hit
        # AttributeError on tenants without (working) Langfuse credentials.
        # Default it unconditionally; it is only meaningful when
        # ``self.langfuse`` is set.
        self.trace_context = {}
        if langfuse_keys:
            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
            if langfuse.auth_check():
                self.langfuse = langfuse
                trace_id = self.langfuse.create_trace_id()
                self.trace_context = {"trace_id": trace_id}
|
||||||
@ -48,7 +48,8 @@ from werkzeug.http import HTTP_STATUS_CODES
|
|||||||
from api import settings
|
from api import settings
|
||||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
from api.db.services.llm_service import LLMService
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||||
|
|
||||||
|
|||||||
@ -197,7 +197,7 @@ def question_proposal(chat_mdl, content, topn=3):
|
|||||||
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
|
|
||||||
if not chat_mdl:
|
if not chat_mdl:
|
||||||
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||||
@ -231,7 +231,7 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_
|
|||||||
def cross_languages(tenant_id, llm_id, query, languages=[]):
|
def cross_languages(tenant_id, llm_id, query, languages=[]):
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
|
|
||||||
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user