Feat: Redesign and refactor agent module (#9113)

### What problem does this PR solve?

#9082 #6365

<u> **WARNING: it's not compatible with the older version of `Agent`
module, which means that `Agent` from older versions can not work
anymore.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-07-30 19:41:09 +08:00
committed by GitHub
parent 07e37560fc
commit d9fe279dde
124 changed files with 7744 additions and 18226 deletions

View File

@ -14,13 +14,18 @@
# limitations under the License.
#
import logging
import os
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
from agent.component import LLMParam, LLM
from api.utils.api_utils import timeout
from rag.llm.chat_model import ERROR_PREFIX
class CategorizeParam(GenerateParam):
class CategorizeParam(LLMParam):
"""
Define the Categorize component parameters.
@ -28,10 +33,12 @@ class CategorizeParam(GenerateParam):
def __init__(self):
super().__init__()
self.category_description = {}
self.prompt = ""
self.query = "sys.query"
self.message_history_window_size = 1
self.update_prompt()
def check(self):
super().check()
self.check_positive_integer(self.message_history_window_size, "[Categorize] Message window size > 0")
self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
if not k:
@ -39,76 +46,90 @@ class CategorizeParam(GenerateParam):
if not v.get("to"):
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
def get_prompt(self, chat_hist):
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"type": "line",
"name": "Query"
}
}
def update_prompt(self):
cate_lines = []
for c, desc in self.category_description.items():
for line in desc.get("examples", "").split("\n"):
for line in desc.get("examples", []):
if not line:
continue
cate_lines.append("USER: {}\nCategory: {}".format(line, c))
cate_lines.append("USER: \"" + re.sub(r"\n", " ", line, flags=re.DOTALL) + "\""+c)
descriptions = []
for c, desc in self.category_description.items():
if desc.get("description"):
descriptions.append(
"\nCategory: {}\nDescription: {}".format(c, desc["description"]))
"\n------\nCategory: {}\nDescription: {}".format(c, desc["description"]))
self.prompt = """
Role: You're a text classifier.
Task: You need to categorize the users questions into {} categories, namely: {}
self.sys_prompt = """
You are an advanced classification system that categorizes user questions into specific types. Analyze the input question and classify it into ONE of the following categories:
{}
Here's description of each category:
{}
- {}
You could learn from the following examples:
{}
You could learn from the above examples.
Requirements:
- Just mention the category names, no need for any additional words.
---- Real Data ----
USER: {}\n
""".format(
len(self.category_description.keys()),
"/".join(list(self.category_description.keys())),
"\n".join(descriptions),
"\n\n- ".join(cate_lines),
chat_hist
---- Instructions ----
- Consider both explicit mentions and implied context
- Prioritize the most specific applicable category
- Return only the category name without explanations
- Use "Other" only when no other category fits
""".format(
"\n - ".join(list(self.category_description.keys())),
"\n".join(descriptions)
)
return self.prompt
if cate_lines:
self.sys_prompt += """
---- Examples ----
{}
""".format("\n".join(cate_lines))
class Categorize(Generate, ABC):
class Categorize(LLM, ABC):
component_name = "Categorize"
def _run(self, history, **kwargs):
input = self.get_input()
input = " - ".join(input["content"]) if "content" in input else ""
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg:
msg = [{"role": "user", "content": ""}]
if kwargs.get("sys.query"):
msg[-1]["content"] = kwargs["sys.query"]
self.set_input_value("sys.query", kwargs["sys.query"])
else:
msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
self.set_input_value(self._param.query, msg[-1]["content"])
self._param.update_prompt()
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
self._canvas.set_component_infor(self._id, {"prompt":self._param.get_prompt(input),"messages": [{"role": "user", "content": "\nCategory: "}],"conf": self._param.gen_conf()})
ans = chat_mdl.chat(self._param.get_prompt(input), [{"role": "user", "content": "\nCategory: "}],
self._param.gen_conf())
logging.debug(f"input: {input}, answer: {str(ans)}")
user_prompt = """
---- Real Data ----
{}
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
if ERROR_PREFIX in ans:
raise Exception(ans)
# Count the number of times each category appears in the answer.
category_counts = {}
for c in self._param.category_description.keys():
count = ans.lower().count(c.lower())
category_counts[c] = count
# If a category is found, return the category with the highest count.
cpn_ids = list(self._param.category_description.items())[-1][1]["to"]
max_category = list(self._param.category_description.keys())[0]
if any(category_counts.values()):
max_category = max(category_counts.items(), key=lambda x: x[1])
res = Categorize.be_output(self._param.category_description[max_category[0]]["to"])
self.set_output(res)
return res
max_category = max(category_counts.items(), key=lambda x: x[1])[0]
cpn_ids = self._param.category_description[max_category]["to"]
res = Categorize.be_output(list(self._param.category_description.items())[-1][1]["to"])
self.set_output(res)
return res
def debug(self, **kwargs):
df = self._run([], **kwargs)
cpn_id = df.iloc[0, 0]
return Categorize.be_output(self._canvas.get_component_name(cpn_id))
self.set_output("category_name", max_category)
self.set_output("_next", cpn_ids)