Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve?

#9082 #6365

**WARNING: this is not compatible with the `Agent` module from older versions, which means an `Agent` built with an older version can no longer work.**

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 import logging
+import re
+from functools import partial

 from langfuse import Langfuse

@@ -137,7 +139,7 @@ class TenantLLMService(CommonService):

     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
+    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
@@ -152,12 +154,12 @@ class TenantLLMService(CommonService):
         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return
-            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])
+            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return
-            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
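The repeated change in this hunk is a plain keyword-argument pass-through: whatever extra options `model_instance` receives are forwarded untouched to the model constructor it picks. A minimal sketch of the pattern, with a hypothetical one-entry factory dict and model class standing in for RAGFlow's real ones:

```python
# Minimal sketch of the **kwargs pass-through shown above.
# DemoChat and the one-entry factory dict are hypothetical stand-ins.

class DemoChat:
    def __init__(self, api_key, model_name, base_url=None, **kwargs):
        # Options such as verbose_tool_use land here without the factory
        # function having to name them individually.
        self.extra = kwargs

ChatModel = {"demo": DemoChat}

def model_instance(factory, api_key, model_name, **kwargs):
    if factory not in ChatModel:
        return None
    return ChatModel[factory](api_key, model_name, **kwargs)

mdl = model_instance("demo", "sk-placeholder", "demo-llm", verbose_tool_use=True)
assert mdl.extra == {"verbose_tool_use": True}
```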
@@ -221,20 +223,21 @@ class TenantLLMService(CommonService):
         for llm_factory in llm_factories:
             for llm in llm_factory["llm"]:
                 if llm_id == llm["llm_name"]:
-                    return llm["model_type"].strip(",")[-1]
+                    return llm["model_type"].split(",")[-1]


 class LLMBundle:
-    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
+    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
         self.tenant_id = tenant_id
         self.llm_type = llm_type
         self.llm_name = llm_name
-        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
         assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         self.max_length = model_config.get("max_tokens", 8192)

         self.is_tools = model_config.get("is_tools", False)
+        self.verbose_tool_use = kwargs.get("verbose_tool_use")

         langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
         if langfuse_keys:
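Besides threading `**kwargs` through `__init__`, this hunk fixes a genuine bug: `str.strip(",")` only trims commas from the ends of the string, so indexing the result with `[-1]` returned its last character instead of the last comma-separated model type. A quick illustration (the model-type value is made up):

```python
model_type = "chat,image2text"
model_type.strip(",")[-1]  # -> "t": last character of the unchanged string (old bug)
model_type.split(",")[-1]  # -> "image2text": last comma-separated field (the fix)
```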
@@ -331,7 +334,7 @@ class LLMBundle:

         return txt

-    def tts(self, text):
+    def tts(self, text: str) -> None:
         if self.langfuse:
             span = self.trace.span(name="tts", input={"text": text})

@@ -359,17 +362,20 @@ class LLMBundle:

         return txt[last_think_end + len("</think>") :]

-    def chat(self, system, history, gen_conf):
+    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
         if self.langfuse:
             generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})

-        chat = self.mdl.chat
+        chat_partial = partial(self.mdl.chat, system, history, gen_conf)
         if self.is_tools and self.mdl.is_tools:
-            chat = self.mdl.chat_with_tools
+            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)

-        txt, used_tokens = chat(system, history, gen_conf)
+        txt, used_tokens = chat_partial(**kwargs)
         txt = self._remove_reasoning_content(txt)

+        if not self.verbose_tool_use:
+            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+
         if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

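Binding `system`, `history`, and `gen_conf` into a `functools.partial` means the method selects `chat` or `chat_with_tools` once and then calls either one the same way, with the caller's `**kwargs` appended. A self-contained sketch of that dispatch; both stub functions are hypothetical stand-ins for the model methods:

```python
from functools import partial

# Hypothetical stand-ins for self.mdl.chat and self.mdl.chat_with_tools.
def chat(system, history, gen_conf, **kwargs):
    return "plain answer", 10          # (text, used_tokens)

def chat_with_tools(system, history, gen_conf, **kwargs):
    return "tool-using answer", 25

use_tools = True
fn = chat_with_tools if use_tools else chat
chat_partial = partial(fn, "You are helpful.", [], {})

# Extra keyword arguments flow through to whichever function was bound.
txt, used_tokens = chat_partial(temperature=0.2)
```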
@@ -378,17 +384,17 @@ class LLMBundle:

         return txt

-    def chat_streamly(self, system, history, gen_conf):
+    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
         if self.langfuse:
             generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

         ans = ""
-        chat_streamly = self.mdl.chat_streamly
+        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
         total_tokens = 0
         if self.is_tools and self.mdl.is_tools:
-            chat_streamly = self.mdl.chat_streamly_with_tools
+            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)

-        for txt in chat_streamly(system, history, gen_conf):
+        for txt in chat_partial(**kwargs):
             if isinstance(txt, int):
                 total_tokens = txt
                 if self.langfuse:
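The streaming variant keeps its generator contract: each yield is the answer accumulated so far, and the underlying model signals total token usage by yielding a final int. A hedged usage sketch, assuming the `LLMBundle` constructor shown earlier; the tenant id and model name are placeholders:

```python
# Placeholder tenant id and model name; assumes this module's own imports.
bundle = LLMBundle("tenant-0", LLMType.CHAT, llm_name="demo-llm",
                   verbose_tool_use=False)

final = ""
for ans in bundle.chat_streamly("You are helpful.",
                                [{"role": "user", "content": "Hi"}],
                                {"temperature": 0.7}):
    final = ans  # each yield replaces the previous, shorter partial answer
print(final)
```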
@@ -398,8 +404,12 @@ class LLMBundle:
             if txt.endswith("</think>"):
                 ans = ans.rstrip("</think>")

+            if not self.verbose_tool_use:
+                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+
             ans += txt
             yield ans

         if total_tokens > 0:
             if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
                 logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
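Both `chat` and `chat_streamly` now scrub `<tool_call>…</tool_call>` blocks unless `verbose_tool_use` is set, and `re.DOTALL` is what allows `.` to match across the newlines inside a multi-line tool call. A small demonstration with made-up content:

```python
import re

txt = 'Answer:\n<tool_call>\n{"name": "search"}\n</tool_call>\ndone'
re.sub(r"<tool_call>.*?</tool_call>", "", txt)
# -> unchanged: without re.DOTALL, "." cannot cross the newlines inside the block
re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
# -> 'Answer:\n\ndone'
```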