mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve?

#9082 #6365

<u>**WARNING: this change is not compatible with older versions of the `Agent` module, which means that agents created with older versions will no longer work.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -14,13 +14,18 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from agent.component import GenerateParam, Generate
|
||||
from agent.component import LLMParam, LLM
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.llm.chat_model import ERROR_PREFIX
|
||||
|
||||
|
||||
class CategorizeParam(GenerateParam):
|
||||
class CategorizeParam(LLMParam):
|
||||
|
||||
"""
|
||||
Define the Categorize component parameters.
|
||||
@ -28,10 +33,12 @@ class CategorizeParam(GenerateParam):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.category_description = {}
|
||||
self.prompt = ""
|
||||
self.query = "sys.query"
|
||||
self.message_history_window_size = 1
|
||||
self.update_prompt()
|
||||
|
||||
def check(self):
|
||||
super().check()
|
||||
self.check_positive_integer(self.message_history_window_size, "[Categorize] Message window size > 0")
|
||||
self.check_empty(self.category_description, "[Categorize] Category examples")
|
||||
for k, v in self.category_description.items():
|
||||
if not k:
|
||||
@ -39,76 +46,90 @@ class CategorizeParam(GenerateParam):
|
||||
if not v.get("to"):
|
||||
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
|
||||
|
||||
def get_prompt(self, chat_hist):
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"type": "line",
|
||||
"name": "Query"
|
||||
}
|
||||
}
|
||||
|
||||
def update_prompt(self):
|
||||
cate_lines = []
|
||||
for c, desc in self.category_description.items():
|
||||
for line in desc.get("examples", "").split("\n"):
|
||||
for line in desc.get("examples", []):
|
||||
if not line:
|
||||
continue
|
||||
cate_lines.append("USER: {}\nCategory: {}".format(line, c))
|
||||
cate_lines.append("USER: \"" + re.sub(r"\n", " ", line, flags=re.DOTALL) + "\" → "+c)
|
||||
|
||||
descriptions = []
|
||||
for c, desc in self.category_description.items():
|
||||
if desc.get("description"):
|
||||
descriptions.append(
|
||||
"\nCategory: {}\nDescription: {}".format(c, desc["description"]))
|
||||
"\n------\nCategory: {}\nDescription: {}".format(c, desc["description"]))
|
||||
|
||||
self.prompt = """
|
||||
Role: You're a text classifier.
|
||||
Task: You need to categorize the user’s questions into {} categories, namely: {}
|
||||
self.sys_prompt = """
|
||||
You are an advanced classification system that categorizes user questions into specific types. Analyze the input question and classify it into ONE of the following categories:
|
||||
{}
|
||||
|
||||
Here's description of each category:
|
||||
{}
|
||||
- {}
|
||||
|
||||
You could learn from the following examples:
|
||||
{}
|
||||
You could learn from the above examples.
|
||||
|
||||
Requirements:
|
||||
- Just mention the category names, no need for any additional words.
|
||||
|
||||
---- Real Data ----
|
||||
USER: {}\n
|
||||
""".format(
|
||||
len(self.category_description.keys()),
|
||||
"/".join(list(self.category_description.keys())),
|
||||
"\n".join(descriptions),
|
||||
"\n\n- ".join(cate_lines),
|
||||
chat_hist
|
||||
---- Instructions ----
|
||||
- Consider both explicit mentions and implied context
|
||||
- Prioritize the most specific applicable category
|
||||
- Return only the category name without explanations
|
||||
- Use "Other" only when no other category fits
|
||||
|
||||
""".format(
|
||||
"\n - ".join(list(self.category_description.keys())),
|
||||
"\n".join(descriptions)
|
||||
)
|
||||
return self.prompt
|
||||
|
||||
if cate_lines:
|
||||
self.sys_prompt += """
|
||||
---- Examples ----
|
||||
{}
|
||||
""".format("\n".join(cate_lines))
|
||||
|
||||
|
||||
class Categorize(Generate, ABC):
|
||||
class Categorize(LLM, ABC):
|
||||
component_name = "Categorize"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
input = self.get_input()
|
||||
input = " - ".join(input["content"]) if "content" in input else ""
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
def _invoke(self, **kwargs):
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if not msg:
|
||||
msg = [{"role": "user", "content": ""}]
|
||||
if kwargs.get("sys.query"):
|
||||
msg[-1]["content"] = kwargs["sys.query"]
|
||||
self.set_input_value("sys.query", kwargs["sys.query"])
|
||||
else:
|
||||
msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
|
||||
self.set_input_value(self._param.query, msg[-1]["content"])
|
||||
self._param.update_prompt()
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
self._canvas.set_component_infor(self._id, {"prompt":self._param.get_prompt(input),"messages": [{"role": "user", "content": "\nCategory: "}],"conf": self._param.gen_conf()})
|
||||
|
||||
ans = chat_mdl.chat(self._param.get_prompt(input), [{"role": "user", "content": "\nCategory: "}],
|
||||
self._param.gen_conf())
|
||||
logging.debug(f"input: {input}, answer: {str(ans)}")
|
||||
user_prompt = """
|
||||
---- Real Data ----
|
||||
{} →
|
||||
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
||||
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
||||
if ERROR_PREFIX in ans:
|
||||
raise Exception(ans)
|
||||
# Count the number of times each category appears in the answer.
|
||||
category_counts = {}
|
||||
for c in self._param.category_description.keys():
|
||||
count = ans.lower().count(c.lower())
|
||||
category_counts[c] = count
|
||||
|
||||
# If a category is found, return the category with the highest count.
|
||||
|
||||
cpn_ids = list(self._param.category_description.items())[-1][1]["to"]
|
||||
max_category = list(self._param.category_description.keys())[0]
|
||||
if any(category_counts.values()):
|
||||
max_category = max(category_counts.items(), key=lambda x: x[1])
|
||||
res = Categorize.be_output(self._param.category_description[max_category[0]]["to"])
|
||||
self.set_output(res)
|
||||
return res
|
||||
max_category = max(category_counts.items(), key=lambda x: x[1])[0]
|
||||
cpn_ids = self._param.category_description[max_category]["to"]
|
||||
|
||||
res = Categorize.be_output(list(self._param.category_description.items())[-1][1]["to"])
|
||||
self.set_output(res)
|
||||
return res
|
||||
|
||||
def debug(self, **kwargs):
|
||||
df = self._run([], **kwargs)
|
||||
cpn_id = df.iloc[0, 0]
|
||||
return Categorize.be_output(self._canvas.get_component_name(cpn_id))
|
||||
self.set_output("category_name", max_category)
|
||||
self.set_output("_next", cpn_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user