Mirror of https://github.com/infiniflow/ragflow.git

Refa: split services about llm. (#9450)

### What problem does this PR solve?

Split the LLM-related services: `TenantLLMService`, `LLMFactoriesService`, and the new `LLM4Tenant` base class move out of `api/db/services/llm_service.py` into a new `api/db/services/tenant_llm_service.py`; a `get_init_tenant_llm` helper replaces the tenant-LLM seeding logic duplicated in `user_register` and `init_superuser`; and imports are updated across the API apps and services.

### Type of change

- [x] Refactoring
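For modules that consumed these services, the import migration implied by this commit looks roughly like the sketch below (module paths are taken from the diff; the exact symbols imported per file vary):

```python
# Before this commit, tenant- and factory-facing services were imported from llm_service:
#   from api.db.services.llm_service import LLMBundle, LLMService, TenantLLMService, LLMFactoriesService

# After this commit, only the model-catalog/bundle pieces stay in llm_service,
# while the tenant-scoped services come from the new module:
from api.db.services.llm_service import LLMBundle, LLMService, get_init_tenant_llm
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
```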
@@ -28,8 +28,8 @@ from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService
from api.db.services.user_service import UserTenantService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService, TenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question

@@ -18,7 +18,7 @@ from flask import request
from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService
from api.db import StatusEnum
from api.db.services.llm_service import TenantLLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService
from api import settings

@@ -17,7 +17,8 @@ import logging
import json
from flask import request
from flask_login import login_required, current_user
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db import StatusEnum, LLMType

@@ -21,7 +21,7 @@ from api import settings
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required

@@ -32,7 +32,8 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
from rag.app.qa import beAdoc, rmPrefix

@@ -16,20 +16,17 @@
import json
import re
import time

import tiktoken
from flask import Response, jsonify, request

from agent.canvas import Canvas
from api.db import LLMType, StatusEnum
from api.db.db_models import API4Conversation, APIToken
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService, completionOpenAI
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid

@@ -28,7 +28,7 @@ from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.utils import (
    current_timestamp,

@@ -619,57 +619,8 @@ def user_register(user_id, user):
        "size": 0,
        "location": "",
    }
    tenant_llm = []

    seen = set()
    factory_configs = []
    for factory_config in [
        settings.CHAT_CFG,
        settings.EMBEDDING_CFG,
        settings.ASR_CFG,
        settings.IMAGE2TEXT_CFG,
        settings.RERANK_CFG,
    ]:
        factory_name = factory_config["factory"]
        if factory_name not in seen:
            seen.add(factory_name)
            factory_configs.append(factory_config)

    for factory_config in factory_configs:
        for llm in LLMService.query(fid=factory_config["factory"]):
            tenant_llm.append(
                {
                    "tenant_id": user_id,
                    "llm_factory": factory_config["factory"],
                    "llm_name": llm.llm_name,
                    "model_type": llm.model_type,
                    "api_key": factory_config["api_key"],
                    "api_base": factory_config["base_url"],
                    "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
                }
            )

    if settings.LIGHTEN != 1:
        for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
            mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
            tenant_llm.append(
                {
                    "tenant_id": user_id,
                    "llm_factory": fid,
                    "llm_name": mdlnm,
                    "model_type": "embedding",
                    "api_key": "",
                    "api_base": "",
                    "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
                }
            )

    unique = {}
    for item in tenant_llm:
        key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
        if key not in unique:
            unique[key] = item
    tenant_llm = list(unique.values())
    tenant_llm = get_init_tenant_llm(user_id)

    if not UserService.save(**user):
        return

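Both `user_register` above and `init_superuser` below now delegate the default tenant-LLM seeding to the shared helper. A minimal sketch of the intended call site follows; the persistence call (`insert_many`) and its surrounding error handling are assumptions for illustration, not shown in this diff:

```python
# Sketch only: consolidated seeding of default TenantLLM rows after this commit.
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.tenant_llm_service import TenantLLMService


def seed_default_llms(user_id: str) -> None:
    # Builds the de-duplicated default rows from settings.CHAT_CFG, EMBEDDING_CFG, etc.
    tenant_llm = get_init_tenant_llm(user_id)
    # Persisting them is assumed to go through the service layer (exact call not confirmed by the diff).
    TenantLLMService.insert_many(tenant_llm)
```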
@@ -27,7 +27,8 @@ from api.db.services import UserService
from api.db.services.canvas_service import CanvasTemplateService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from api.utils.file_utils import get_project_base_directory

@@ -64,43 +65,7 @@ def init_superuser():
        "role": UserTenantRole.OWNER
    }

    user_id = user_info
    tenant_llm = []

    seen = set()
    factory_configs = []
    for factory_config in [
        settings.CHAT_CFG["factory"],
        settings.EMBEDDING_CFG["factory"],
        settings.ASR_CFG["factory"],
        settings.IMAGE2TEXT_CFG["factory"],
        settings.RERANK_CFG["factory"],
    ]:
        factory_name = factory_config["factory"]
        if factory_name not in seen:
            seen.add(factory_name)
            factory_configs.append(factory_config)

    for factory_config in factory_configs:
        for llm in LLMService.query(fid=factory_config["factory"]):
            tenant_llm.append(
                {
                    "tenant_id": user_id,
                    "llm_factory": factory_config["factory"],
                    "llm_name": llm.llm_name,
                    "model_type": llm.model_type,
                    "api_key": factory_config["api_key"],
                    "api_base": factory_config["base_url"],
                    "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
                }
            )

    unique = {}
    for item in tenant_llm:
        key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
        if key not in unique:
            unique[key] = item
    tenant_llm = list(unique.values())
    tenant_llm = get_init_tenant_llm(user_info["id"])

    if not UserService.save(**user_info):
        logging.error("can't init admin.")

@@ -33,7 +33,8 @@ from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import current_timestamp, datetime_format
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question

@@ -18,246 +18,73 @@ import logging
import re
from functools import partial
from typing import Generator

from langfuse import Langfuse

from api import settings
from api.db import LLMType
from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel


class LLMFactoriesService(CommonService):
    model = LLMFactories
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService


class LLMService(CommonService):
    model = LLM


class TenantLLMService(CommonService):
    model = TenantLLM
def get_init_tenant_llm(user_id):
    from api import settings
    tenant_llm = []

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_name):
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
        if not fid:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
        else:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
    seen = set()
    factory_configs = []
    for factory_config in [
        settings.CHAT_CFG,
        settings.EMBEDDING_CFG,
        settings.ASR_CFG,
        settings.IMAGE2TEXT_CFG,
        settings.RERANK_CFG,
    ]:
        factory_name = factory_config["factory"]
        if factory_name not in seen:
            seen.add(factory_name)
            factory_configs.append(factory_config)

        if (not objs) and fid:
            if fid == "LocalAI":
                mdlnm += "___LocalAI"
            elif fid == "HuggingFace":
                mdlnm += "___HuggingFace"
            elif fid == "OpenAI-API-Compatible":
                mdlnm += "___OpenAI-API"
            elif fid == "VLLM":
                mdlnm += "___VLLM"
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
        if not objs:
            return
        return objs[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()

        return list(objs)

    @staticmethod
    def split_model_name_and_factory(model_name):
        arr = model_name.split("@")
        if len(arr) < 2:
            return model_name, None
        if len(arr) > 2:
            return "@".join(arr[0:-1]), arr[-1]

        # model name must be xxx@yyy
        try:
            model_factories = settings.FACTORY_LLM_INFOS
            model_providers = set([f["name"] for f in model_factories])
            if arr[-1] not in model_providers:
                return model_name, None
            return arr[0], arr[-1]
        except Exception as e:
            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
            return model_name, None

    @classmethod
    @DB.connection_context()
    def get_model_config(cls, tenant_id, llm_type, llm_name=None):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        if llm_type == LLMType.EMBEDDING.value:
            mdlnm = tenant.embd_id if not llm_name else llm_name
        elif llm_type == LLMType.SPEECH2TEXT.value:
            mdlnm = tenant.asr_id
        elif llm_type == LLMType.IMAGE2TEXT.value:
            mdlnm = tenant.img2txt_id if not llm_name else llm_name
        elif llm_type == LLMType.CHAT.value:
            mdlnm = tenant.llm_id if not llm_name else llm_name
        elif llm_type == LLMType.RERANK:
            mdlnm = tenant.rerank_id if not llm_name else llm_name
        elif llm_type == LLMType.TTS:
            mdlnm = tenant.tts_id if not llm_name else llm_name
        else:
            assert False, "LLM type error"

        model_config = cls.get_api_key(tenant_id, mdlnm)
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
        if not model_config:  # for some cases seems fid mismatch
            model_config = cls.get_api_key(tenant_id, mdlnm)
        if model_config:
            model_config = model_config.to_dict()
            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
            if not llm and fid:  # for some cases seems fid mismatch
                llm = LLMService.query(llm_name=mdlnm)
            if llm:
                model_config["is_tools"] = llm[0].is_tools
        if not model_config:
            if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
            if not model_config:
                if mdlnm == "flag-embedding":
                    model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
                else:
                    if not mdlnm:
                        raise LookupError(f"Type of {llm_type} model is not set.")
                    raise LookupError("Model({}) not authorized".format(mdlnm))
        return model_config

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        kwargs.update({"provider": model_config["llm_factory"]})
        if llm_type == LLMType.EMBEDDING.value:
            if model_config["llm_factory"] not in EmbeddingModel:
                return
            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.RERANK:
            if model_config["llm_factory"] not in RerankModel:
                return
            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.IMAGE2TEXT.value:
            if model_config["llm_factory"] not in CvModel:
                return
            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.CHAT.value:
            if model_config["llm_factory"] not in ChatModel:
                return
            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.SPEECH2TEXT:
            if model_config["llm_factory"] not in Seq2txtModel:
                return
            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
        if llm_type == LLMType.TTS:
            if model_config["llm_factory"] not in TTSModel:
                return
            return TTSModel[model_config["llm_factory"]](
                model_config["api_key"],
                model_config["llm_name"],
                base_url=model_config["api_base"],
    for factory_config in factory_configs:
        for llm in LLMService.query(fid=factory_config["factory"]):
            tenant_llm.append(
                {
                    "tenant_id": user_id,
                    "llm_factory": factory_config["factory"],
                    "llm_name": llm.llm_name,
                    "model_type": llm.model_type,
                    "api_key": factory_config["api_key"],
                    "api_base": factory_config["base_url"],
                    "max_tokens": llm.max_tokens if llm.max_tokens else 8192,
                }
            )

    @classmethod
    @DB.connection_context()
    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            logging.error(f"Tenant not found: {tenant_id}")
            return 0

        llm_map = {
            LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
            LLMType.SPEECH2TEXT.value: tenant.asr_id,
            LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
            LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
            LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
        }

        mdlnm = llm_map.get(llm_type)
        if mdlnm is None:
            logging.error(f"LLM type error: {llm_type}")
            return 0

        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)

        try:
            num = (
                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
                .execute()
    if settings.LIGHTEN != 1:
        for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
            mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
            tenant_llm.append(
                {
                    "tenant_id": user_id,
                    "llm_factory": fid,
                    "llm_name": mdlnm,
                    "model_type": "embedding",
                    "api_key": "",
                    "api_base": "",
                    "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
                }
            )
        except Exception:
            logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
            return 0

        return num

    @classmethod
    @DB.connection_context()
    def get_openai_models(cls):
        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
        return list(objs)

    @staticmethod
    def llm_id2llm_type(llm_id: str) -> str | None:
        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
        llm_factories = settings.FACTORY_LLM_INFOS
        for llm_factory in llm_factories:
            for llm in llm_factory["llm"]:
                if llm_id == llm["llm_name"]:
                    return llm["model_type"].split(",")[-1]

        for llm in LLMService.query(llm_name=llm_id):
            return llm.model_type

        llm = TenantLLMService.get_or_none(llm_name=llm_id)
        if llm:
            return llm.model_type
        for llm in TenantLLMService.query(llm_name=llm_id):
            return llm.model_type
    unique = {}
    for item in tenant_llm:
        key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
        if key not in unique:
            unique[key] = item
    return list(unique.values())


class LLMBundle:
class LLMBundle(LLM4Tenant):
    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        self.max_length = model_config.get("max_tokens", 8192)

        self.is_tools = model_config.get("is_tools", False)
        self.verbose_tool_use = kwargs.get("verbose_tool_use")

        langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
        self.langfuse = None
        if langfuse_keys:
            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
            if langfuse.auth_check():
                self.langfuse = langfuse
                trace_id = self.langfuse.create_trace_id()
                self.trace_context = {"trace_id": trace_id}
        super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)

    def bind_tools(self, toolcall_session, tools):
        if not self.is_tools:

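The net effect on `llm_service.py` is that `LLMBundle` keeps its public behaviour but inherits its setup from the relocated base class; a condensed sketch of the resulting layering, based on the lines shown in this hunk (method bodies elided):

```python
# Condensed view of the post-refactor structure; not the full implementation.
from api.db.services.tenant_llm_service import LLM4Tenant


class LLMBundle(LLM4Tenant):
    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        # Model resolution, max_length, tool flags and Langfuse wiring now happen
        # in LLM4Tenant.__init__ (see api/db/services/tenant_llm_service.py below).
        super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
```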
api/db/services/tenant_llm_service.py (new file, 252 lines)

@@ -0,0 +1,252 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from langfuse import Langfuse
from api import settings
from api.db import LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel


class LLMFactoriesService(CommonService):
    model = LLMFactories


class TenantLLMService(CommonService):
    model = TenantLLM

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_name):
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
        if not fid:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
        else:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)

        if (not objs) and fid:
            if fid == "LocalAI":
                mdlnm += "___LocalAI"
            elif fid == "HuggingFace":
                mdlnm += "___HuggingFace"
            elif fid == "OpenAI-API-Compatible":
                mdlnm += "___OpenAI-API"
            elif fid == "VLLM":
                mdlnm += "___VLLM"
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
        if not objs:
            return
        return objs[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()

        return list(objs)

    @staticmethod
    def split_model_name_and_factory(model_name):
        arr = model_name.split("@")
        if len(arr) < 2:
            return model_name, None
        if len(arr) > 2:
            return "@".join(arr[0:-1]), arr[-1]

        # model name must be xxx@yyy
        try:
            model_factories = settings.FACTORY_LLM_INFOS
            model_providers = set([f["name"] for f in model_factories])
            if arr[-1] not in model_providers:
                return model_name, None
            return arr[0], arr[-1]
        except Exception as e:
            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
            return model_name, None

    @classmethod
    @DB.connection_context()
    def get_model_config(cls, tenant_id, llm_type, llm_name=None):
        from api.db.services.llm_service import LLMService
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        if llm_type == LLMType.EMBEDDING.value:
            mdlnm = tenant.embd_id if not llm_name else llm_name
        elif llm_type == LLMType.SPEECH2TEXT.value:
            mdlnm = tenant.asr_id
        elif llm_type == LLMType.IMAGE2TEXT.value:
            mdlnm = tenant.img2txt_id if not llm_name else llm_name
        elif llm_type == LLMType.CHAT.value:
            mdlnm = tenant.llm_id if not llm_name else llm_name
        elif llm_type == LLMType.RERANK:
            mdlnm = tenant.rerank_id if not llm_name else llm_name
        elif llm_type == LLMType.TTS:
            mdlnm = tenant.tts_id if not llm_name else llm_name
        else:
            assert False, "LLM type error"

        model_config = cls.get_api_key(tenant_id, mdlnm)
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
        if not model_config:  # for some cases seems fid mismatch
            model_config = cls.get_api_key(tenant_id, mdlnm)
        if model_config:
            model_config = model_config.to_dict()
            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
            if not llm and fid:  # for some cases seems fid mismatch
                llm = LLMService.query(llm_name=mdlnm)
            if llm:
                model_config["is_tools"] = llm[0].is_tools
        if not model_config:
            if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
            if not model_config:
                if mdlnm == "flag-embedding":
                    model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
                else:
                    if not mdlnm:
                        raise LookupError(f"Type of {llm_type} model is not set.")
                    raise LookupError("Model({}) not authorized".format(mdlnm))
        return model_config

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        kwargs.update({"provider": model_config["llm_factory"]})
        if llm_type == LLMType.EMBEDDING.value:
            if model_config["llm_factory"] not in EmbeddingModel:
                return
            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.RERANK:
            if model_config["llm_factory"] not in RerankModel:
                return
            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.IMAGE2TEXT.value:
            if model_config["llm_factory"] not in CvModel:
                return
            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.CHAT.value:
            if model_config["llm_factory"] not in ChatModel:
                return
            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.SPEECH2TEXT:
            if model_config["llm_factory"] not in Seq2txtModel:
                return
            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
        if llm_type == LLMType.TTS:
            if model_config["llm_factory"] not in TTSModel:
                return
            return TTSModel[model_config["llm_factory"]](
                model_config["api_key"],
                model_config["llm_name"],
                base_url=model_config["api_base"],
            )

    @classmethod
    @DB.connection_context()
    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            logging.error(f"Tenant not found: {tenant_id}")
            return 0

        llm_map = {
            LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
            LLMType.SPEECH2TEXT.value: tenant.asr_id,
            LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
            LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
            LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
        }

        mdlnm = llm_map.get(llm_type)
        if mdlnm is None:
            logging.error(f"LLM type error: {llm_type}")
            return 0

        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)

        try:
            num = (
                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
                .execute()
            )
        except Exception:
            logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
            return 0

        return num

    @classmethod
    @DB.connection_context()
    def get_openai_models(cls):
        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
        return list(objs)

    @staticmethod
    def llm_id2llm_type(llm_id: str) -> str | None:
        from api.db.services.llm_service import LLMService
        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
        llm_factories = settings.FACTORY_LLM_INFOS
        for llm_factory in llm_factories:
            for llm in llm_factory["llm"]:
                if llm_id == llm["llm_name"]:
                    return llm["model_type"].split(",")[-1]

        for llm in LLMService.query(llm_name=llm_id):
            return llm.model_type

        llm = TenantLLMService.get_or_none(llm_name=llm_id)
        if llm:
            return llm.model_type
        for llm in TenantLLMService.query(llm_name=llm_id):
            return llm.model_type


class LLM4Tenant:
    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        self.max_length = model_config.get("max_tokens", 8192)

        self.is_tools = model_config.get("is_tools", False)
        self.verbose_tool_use = kwargs.get("verbose_tool_use")

        langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
        self.langfuse = None
        if langfuse_keys:
            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
            if langfuse.auth_check():
                self.langfuse = langfuse
                trace_id = self.langfuse.create_trace_id()
                self.trace_context = {"trace_id": trace_id}

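As a quick illustration of the `name@factory` convention handled by `split_model_name_and_factory` (assuming "BAAI" appears in `settings.FACTORY_LLM_INFOS`, as the built-in embedding default `BAAI/bge-large-zh-v1.5@BAAI` used earlier in this diff suggests):

```python
from api.db.services.tenant_llm_service import TenantLLMService

# "model@factory" splits into (model, factory) when the suffix is a known provider name.
name, fid = TenantLLMService.split_model_name_and_factory("BAAI/bge-large-zh-v1.5@BAAI")
assert (name, fid) == ("BAAI/bge-large-zh-v1.5", "BAAI")

# A name without a recognized "@factory" suffix comes back with factory=None,
# and callers fall back to querying by name only (see get_api_key above).
name, fid = TenantLLMService.split_model_name_and_factory("flag-embedding")
assert (name, fid) == ("flag-embedding", None)
```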
@@ -48,7 +48,8 @@ from werkzeug.http import HTTP_STATUS_CODES
from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db.db_models import APIToken
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions