### What problem does this PR solve?

All models pass the mock-response tests, which means that as long as a model returns a well-formed response, everything should work as expected. However, not all models have been fully tested in a real environment with a real API_KEY. I suggest either actively monitoring the refactored models over the coming period and fixing them step by step, or waiting to merge until most of them have been tested in a practical environment.

### Type of change

- [x] Refactoring
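A minimal sketch of the kind of mock-response test referred to above (the class and helper names here are hypothetical; the project's real test suite may be organized differently):

```python
from unittest.mock import patch


class DummyChatModel:
    """Stand-in for a refactored chat model; real implementations call the provider's API."""

    def chat(self, system, history, gen_conf):
        raise NotImplementedError


def test_chat_returns_text_and_token_count():
    # Patch the provider call so no real API_KEY is needed.
    with patch.object(DummyChatModel, "chat", return_value=("hello", 5)):
        txt, used_tokens = DummyChatModel().chat("system prompt", [], {})
        assert txt == "hello"
        assert used_tokens == 5
```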
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import re
from functools import partial
from typing import Generator

from langfuse import Langfuse

from api import settings
from api.db import LLMType
from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel


class LLMFactoriesService(CommonService):
    model = LLMFactories


class LLMService(CommonService):
    model = LLM


class TenantLLMService(CommonService):
    model = TenantLLM

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_name):
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
        if not fid:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
        else:
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)

        if (not objs) and fid:
            if fid == "LocalAI":
                mdlnm += "___LocalAI"
            elif fid == "HuggingFace":
                mdlnm += "___HuggingFace"
            elif fid == "OpenAI-API-Compatible":
                mdlnm += "___OpenAI-API"
            elif fid == "VLLM":
                mdlnm += "___VLLM"
            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
        if not objs:
            return
        return objs[0]

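    # Illustration (hypothetical model name): credentials for a LocalAI model may be
    # stored under a suffixed name such as "my-model___LocalAI", so when the plain
    # lookup of "my-model" finds nothing, the suffixed form is tried as a fallback.
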
    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()

        return list(objs)

    @staticmethod
    def split_model_name_and_factory(model_name):
        arr = model_name.split("@")
        if len(arr) < 2:
            return model_name, None
        if len(arr) > 2:
            return "@".join(arr[0:-1]), arr[-1]

        # model name must be xxx@yyy
        try:
            model_factories = settings.FACTORY_LLM_INFOS
            model_providers = set([f["name"] for f in model_factories])
            if arr[-1] not in model_providers:
                return model_name, None
            return arr[0], arr[-1]
        except Exception as e:
            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
        return model_name, None

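    # Behavior sketch (hypothetical names, assuming "OpenAI" is a registered factory):
    #   split_model_name_and_factory("gpt-4o@OpenAI")  -> ("gpt-4o", "OpenAI")
    #   split_model_name_and_factory("gpt-4o")         -> ("gpt-4o", None)
    #   split_model_name_and_factory("a@b@OpenAI")     -> ("a@b", "OpenAI")
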
    @classmethod
    @DB.connection_context()
    def get_model_config(cls, tenant_id, llm_type, llm_name=None):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")

        if llm_type == LLMType.EMBEDDING.value:
            mdlnm = tenant.embd_id if not llm_name else llm_name
        elif llm_type == LLMType.SPEECH2TEXT.value:
            mdlnm = tenant.asr_id
        elif llm_type == LLMType.IMAGE2TEXT.value:
            mdlnm = tenant.img2txt_id if not llm_name else llm_name
        elif llm_type == LLMType.CHAT.value:
            mdlnm = tenant.llm_id if not llm_name else llm_name
        elif llm_type == LLMType.RERANK:
            mdlnm = tenant.rerank_id if not llm_name else llm_name
        elif llm_type == LLMType.TTS:
            mdlnm = tenant.tts_id if not llm_name else llm_name
        else:
            assert False, "LLM type error"

        model_config = cls.get_api_key(tenant_id, mdlnm)
        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
        if not model_config:  # in some cases the stored factory id (fid) does not match
            model_config = cls.get_api_key(tenant_id, mdlnm)
        if model_config:
            model_config = model_config.to_dict()
            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
            if not llm and fid:  # in some cases the stored factory id (fid) does not match
                llm = LLMService.query(llm_name=mdlnm)
            if llm:
                model_config["is_tools"] = llm[0].is_tools
        if not model_config:
            if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
            if not model_config:
                if mdlnm == "flag-embedding":
                    model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
                else:
                    if not mdlnm:
                        raise LookupError(f"Type of {llm_type} model is not set.")
                    raise LookupError("Model({}) not authorized".format(mdlnm))
        return model_config

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        kwargs.update({"provider": model_config["llm_factory"]})
        if llm_type == LLMType.EMBEDDING.value:
            if model_config["llm_factory"] not in EmbeddingModel:
                return
            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.RERANK:
            if model_config["llm_factory"] not in RerankModel:
                return
            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

        if llm_type == LLMType.IMAGE2TEXT.value:
            if model_config["llm_factory"] not in CvModel:
                return
            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.CHAT.value:
            if model_config["llm_factory"] not in ChatModel:
                return
            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)

        if llm_type == LLMType.SPEECH2TEXT:
            if model_config["llm_factory"] not in Seq2txtModel:
                return
            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
        if llm_type == LLMType.TTS:
            if model_config["llm_factory"] not in TTSModel:
                return
            return TTSModel[model_config["llm_factory"]](
                model_config["api_key"],
                model_config["llm_name"],
                base_url=model_config["api_base"],
            )

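    # Usage sketch (hypothetical tenant id and model name):
    #   chat_mdl = TenantLLMService.model_instance("tenant-123", LLMType.CHAT.value, "gpt-4o@OpenAI")
    # Returns None when the configured factory has no registered implementation for
    # the requested model type.
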
    @classmethod
    @DB.connection_context()
    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            logging.error(f"Tenant not found: {tenant_id}")
            return 0

        llm_map = {
            LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
            LLMType.SPEECH2TEXT.value: tenant.asr_id,
            LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
            LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
            LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
            LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
        }

        mdlnm = llm_map.get(llm_type)
        if mdlnm is None:
            logging.error(f"LLM type error: {llm_type}")
            return 0

        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)

        try:
            num = (
                cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
                .execute()
            )
        except Exception:
            logging.exception("TenantLLMService.increase_usage got exception; failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
            return 0

        return num

    @classmethod
    @DB.connection_context()
    def get_openai_models(cls):
        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
        return list(objs)

    @staticmethod
    def llm_id2llm_type(llm_id: str) -> str | None:
        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
        llm_factories = settings.FACTORY_LLM_INFOS
        for llm_factory in llm_factories:
            for llm in llm_factory["llm"]:
                if llm_id == llm["llm_name"]:
                    return llm["model_type"].split(",")[-1]

        for llm in LLMService.query(llm_name=llm_id):
            return llm.model_type

        llm = TenantLLMService.get_or_none(llm_name=llm_id)
        if llm:
            return llm.model_type
        for llm in TenantLLMService.query(llm_name=llm_id):
            return llm.model_type

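    # Resolution order: the built-in factory catalog (settings.FACTORY_LLM_INFOS) is
    # consulted first, then the LLM table, then tenant-specific records; None is
    # returned when the id is unknown everywhere.

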
class LLMBundle:
    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
        self.max_length = model_config.get("max_tokens", 8192)

        self.is_tools = model_config.get("is_tools", False)
        self.verbose_tool_use = kwargs.get("verbose_tool_use")

        langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
        self.langfuse = None
        if langfuse_keys:
            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
            if langfuse.auth_check():
                self.langfuse = langfuse
                trace_id = self.langfuse.create_trace_id()
                self.trace_context = {"trace_id": trace_id}

    def bind_tools(self, toolcall_session, tools):
        if not self.is_tools:
            logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
            return
        self.mdl.bind_tools(toolcall_session, tools)

    def encode(self, texts: list):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})

        embeddings, used_tokens = self.mdl.encode(texts)
        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return embeddings, used_tokens

    def encode_queries(self, query: str):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})

        emd, used_tokens = self.mdl.encode_queries(query)
        llm_name = getattr(self, "llm_name", None)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
            logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return emd, used_tokens

    def similarity(self, query: str, texts: list):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})

        sim, used_tokens = self.mdl.similarity(query, texts)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(usage_details={"total_tokens": used_tokens})
            generation.end()

        return sim, used_tokens

    def describe(self, image, max_tokens=300):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.describe(image)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def describe_with_prompt(self, image, prompt):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})

        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.describe_with_prompt can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def transcription(self, audio):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})

        txt, used_tokens = self.mdl.transcription(audio)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def tts(self, text: str) -> Generator[bytes, None, None]:
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})

        for chunk in self.mdl.tts(text):
            if isinstance(chunk, int):
                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
                    logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                return
            yield chunk

        if self.langfuse:
            generation.end()

    def _remove_reasoning_content(self, txt: str) -> str:
        first_think_start = txt.find("<think>")
        if first_think_start == -1:
            return txt

        last_think_end = txt.rfind("</think>")
        if last_think_end == -1:
            return txt

        if last_think_end < first_think_start:
            return txt

        return txt[last_think_end + len("</think>") :]

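    # Behavior sketch:
    #   "<think>plan...</think>final answer"  ->  "final answer"
    #   "no think markers at all"             ->  unchanged
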
    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

        chat_partial = partial(self.mdl.chat, system, history, gen_conf)
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)

        txt, used_tokens = chat_partial(**kwargs)
        txt = self._remove_reasoning_content(txt)

        if not self.verbose_tool_use:
            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt

    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

        ans = ""
        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
        total_tokens = 0
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)

        for txt in chat_partial(**kwargs):
            if isinstance(txt, int):
                total_tokens = txt
                if self.langfuse:
                    generation.update(output={"output": ans})
                    generation.end()
                break

            if txt.endswith("</think>"):
                ans = ans.removesuffix("</think>")

            if not self.verbose_tool_use:
                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

            ans += txt
            yield ans

        if total_tokens > 0:
            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
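
# Usage sketch (hypothetical tenant id and model name; illustrative only):
#
#     bundle = LLMBundle("tenant-123", LLMType.CHAT, llm_name="gpt-4o@OpenAI")
#     answer = bundle.chat("You are a helpful assistant.", [{"role": "user", "content": "Hello"}])
#     for partial_answer in bundle.chat_streamly("You are a helpful assistant.", [{"role": "user", "content": "Hello"}]):
#         print(partial_answer)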