Feat: Redesign and refactor agent module (#9113)

### What problem does this PR solve?

#9082 #6365

<u>**WARNING: this change is not backward compatible with the older `Agent`
module, which means agents built with earlier versions will no longer
work.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Authored by Kevin Hu on 2025-07-30 19:41:09 +08:00, committed by GitHub
parent 07e37560fc, commit d9fe279dde
124 changed files with 7744 additions and 18226 deletions


@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 import logging
+import re
+from functools import partial
 from langfuse import Langfuse
@@ -137,7 +139,7 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
+    def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
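
The `**kwargs` added to `model_instance` are forwarded verbatim to the model constructors (see the `CvModel`/`ChatModel` calls in the next hunk), so per-instance options such as `verbose_tool_use` can be threaded through the factory without widening its signature. A minimal, self-contained sketch of the pattern — `EchoChat` and `MODEL_FACTORY` are hypothetical stand-ins, not names from this codebase:

```python
# Hypothetical stand-ins for the ChatModel/CvModel registries; only the
# kwargs-forwarding shape is the point here.
class EchoChat:
    def __init__(self, api_key, model_name, base_url=None, **kwargs):
        self.model_name = model_name
        # Options the factory does not know about land here.
        self.verbose_tool_use = kwargs.get("verbose_tool_use", False)

MODEL_FACTORY = {"echo": EchoChat}

def model_instance(factory, api_key, model_name, **kwargs):
    if factory not in MODEL_FACTORY:
        return None
    # Extra keyword arguments flow straight into the constructor.
    return MODEL_FACTORY[factory](api_key, model_name, **kwargs)

mdl = model_instance("echo", "sk-demo", "echo-1", verbose_tool_use=True)
assert mdl.verbose_tool_use is True
```
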
@@ -152,12 +154,12 @@ class TenantLLMService(CommonService):
         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return
-            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])
+            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return
-            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
@@ -221,20 +223,21 @@
         for llm_factory in llm_factories:
             for llm in llm_factory["llm"]:
                 if llm_id == llm["llm_name"]:
-                    return llm["model_type"].strip(",")[-1]
+                    return llm["model_type"].split(",")[-1]
 
 
 class LLMBundle:
-    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
+    def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
         self.tenant_id = tenant_id
         self.llm_type = llm_type
         self.llm_name = llm_name
-        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
         assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         self.max_length = model_config.get("max_tokens", 8192)
         self.is_tools = model_config.get("is_tools", False)
+        self.verbose_tool_use = kwargs.get("verbose_tool_use")
 
         langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
         if langfuse_keys:
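
The `strip(",")` → `split(",")` change above fixes a real bug: `str.strip(",")` only trims commas from the ends of the string, so indexing its result with `[-1]` returned the last character rather than the last comma-separated field. A quick check (the `model_type` value here is illustrative):

```python
model_type = "chat,image2text"
print(model_type.strip(",")[-1])   # "t" -- strip is a no-op here; [-1] grabs the last character
print(model_type.split(",")[-1])   # "image2text" -- the last comma-separated field, as intended
```

The `verbose_tool_use` flag captured in `__init__` above is consumed by the `chat` and `chat_streamly` hunks below.
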
@@ -331,7 +334,7 @@ class LLMBundle:
         return txt
 
-    def tts(self, text):
+    def tts(self, text: str) -> None:
         if self.langfuse:
             span = self.trace.span(name="tts", input={"text": text})
@@ -359,17 +362,20 @@
         return txt[last_think_end + len("</think>") :]
 
-    def chat(self, system, history, gen_conf):
+    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
         if self.langfuse:
             generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})
 
-        chat = self.mdl.chat
+        chat_partial = partial(self.mdl.chat, system, history, gen_conf)
         if self.is_tools and self.mdl.is_tools:
-            chat = self.mdl.chat_with_tools
+            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
 
-        txt, used_tokens = chat(system, history, gen_conf)
+        txt, used_tokens = chat_partial(**kwargs)
         txt = self._remove_reasoning_content(txt)
 
+        if not self.verbose_tool_use:
+            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+
         if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
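
Two mechanics in this hunk are worth spelling out: `functools.partial` pre-binds `system`, `history`, and `gen_conf`, so the eventual call supplies only the pass-through `**kwargs`; and `re.DOTALL` lets the `<tool_call>` filter match across newlines. A self-contained demonstration — `fake_chat` is illustrative, standing in for `mdl.chat` / `mdl.chat_with_tools`:

```python
import re
from functools import partial

def fake_chat(system, history, gen_conf, **kwargs):
    # Echoes a reply containing a multi-line tool-call block, plus a token count.
    return '<tool_call>\n{"name": "search"}\n</tool_call>Answer.', 42

chat_partial = partial(fake_chat, "You are helpful.", [], {"temperature": 0.1})
txt, used_tokens = chat_partial(stop=["\n\n"])  # only the extra kwargs remain to be passed

# Same filter as in the hunk: DOTALL makes "." match the embedded newlines.
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
assert txt == "Answer." and used_tokens == 42
```
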
@@ -378,17 +384,17 @@
         return txt
 
-    def chat_streamly(self, system, history, gen_conf):
+    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
         if self.langfuse:
             generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
 
         ans = ""
-        chat_streamly = self.mdl.chat_streamly
+        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
         total_tokens = 0
         if self.is_tools and self.mdl.is_tools:
-            chat_streamly = self.mdl.chat_streamly_with_tools
+            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
 
-        for txt in chat_streamly(system, history, gen_conf):
+        for txt in chat_partial(**kwargs):
             if isinstance(txt, int):
                 total_tokens = txt
                 if self.langfuse:
@@ -398,8 +404,12 @@
             if txt.endswith("</think>"):
                 ans = ans.rstrip("</think>")
 
+            if not self.verbose_tool_use:
+                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+
             ans += txt
             yield ans
 
         if total_tokens > 0:
             if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
                 logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
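
The streaming path applies the same `<tool_call>` filter chunk by chunk before accumulating into `ans`. A minimal sketch of that loop's behavior, with a fabricated stream (text chunks followed by an int sentinel carrying the total token count, mirroring the loop above); note the per-chunk regex can only strip a tool-call block that arrives whole within a single chunk:

```python
import re

stream = ["Hello ", '<tool_call>{"name": "search"}</tool_call>', "world", 17]

ans, total_tokens = "", 0
for txt in stream:
    if isinstance(txt, int):       # token-count sentinel, not text
        total_tokens = txt
        continue
    # Strip tool-call markup from this chunk before accumulating.
    txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
    ans += txt

assert ans == "Hello world" and total_tokens == 17
```
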