ragflow/agent/component/agent_with_tools.py
Yongteng Lei 27b0550876 Refa: cleanup synchronous functions in agent_with_tools (#11736)
### What problem does this PR solve?

Cleanup synchronous functions in agent_with_tools.

### Type of change

- [x] Refactoring
2025-12-04 14:15:05 +08:00

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
from copy import deepcopy
from functools import partial
from typing import Any
import json_repair
from timeit import default_timer as timer
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
    citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
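

# This module implements the tool-using Agent component: AgentParam declares the
# component's configuration and the parameters it exposes when called as a tool,
# while Agent runs an async ReAct-style loop (plan, call tools, reflect) on top of
# the LLM component.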
class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.
    """

    def __init__(self):
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {
                    "type": "string",
                    "description": "This is the order you need to send to the agent.",
                    "default": "",
                    "required": True
                },
                "reasoning": {
                    "type": "string",
                    "description": (
                        "Supervisor's reasoning for choosing this agent. "
                        "Explain why this agent is being invoked and what is expected of it."
                    ),
                    "required": True
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, "
                        "and state needed by the agent to solve the current query. "
                        "Should be as detailed and self-contained as possible."
                    ),
                    "required": True
                },
            }
        }
        super().__init__()
        self.function_name = "agent"
        self.tools = []
        self.mcp = []
        self.max_rounds = 5
        self.description = ""
class Agent(LLM, ToolBase):
    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        LLM.__init__(self, canvas, id, param)
        self.tools = {}
        for cpn in self._param.tools:
            cpn = self._load_tool_obj(cpn)
            self.tools[cpn.get_meta()["function"]["name"]] = cpn

        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                  max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error,
                                  max_rounds=self._param.max_rounds,
                                  verbose_tool_use=True
                                  )
        self.tool_meta = [v.get_meta() for _, v in self.tools.items()]

        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            for tnm, meta in mcp["tools"].items():
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
                self.tools[tnm] = tool_call_session

        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        # self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)
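
    # Instantiate a tool component from its serialized description; configuration
    # errors are surfaced on the _ERROR output and re-raised.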
    def _load_tool_obj(self, cpn: dict) -> object:
        from agent.component import component_class

        param = component_class(cpn["component_name"] + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
            raise
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(cpn["component_name"])(self._canvas, cpn_id, param)
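
    # The tool-call function name exposed to the parent LLM is taken from the last
    # segment of this component's id (the part after "-->").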
    def get_meta(self) -> dict[str, Any]:
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
        return m

    def get_input_form(self) -> dict[str, dict]:
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        for cpn in self._param.tools:
            if not isinstance(cpn, LLM):
                continue
            res.update(cpn.get_input_form())
        return res
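
    # Look up a JSON schema on the "structured" output, accepting either the schema
    # itself or one nested under a "schema"/"structured" key; return None when no
    # usable schema (with properties) is configured.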
    def _get_output_schema(self):
        try:
            cand = self._param.outputs.get("structured")
        except Exception:
            return None
        if isinstance(cand, dict):
            if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
                return cand
            for k in ("schema", "structured"):
                if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
                    return cand[k]
        return None
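
    # Ask the model to reformat a free-form answer into JSON that matches the
    # structured-output schema prompt.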
    async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
        fmt_msgs = [
            {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
            {"role": "user", "content": text},
        ]
        _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
        return await self._generate_async(fmt_msgs)
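
    # Synchronous entry point kept for non-async callers; it simply drives the async
    # implementation to completion.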
    def _invoke(self, **kwargs):
        return asyncio.run(self._invoke_async(**kwargs))
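
    # Main orchestration: build the user prompt, fall back to the plain LLM path when
    # no tools are configured, stream through a downstream Message component when one
    # is attached, and otherwise run the ReAct loop and (optionally) coerce the final
    # answer into the structured-output schema.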
    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
    async def _invoke_async(self, **kwargs):
        if self.check_if_canceled("Agent processing"):
            return

        if kwargs.get("user_prompt"):
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]

        if not self.tools:
            if self.check_if_canceled("Agent processing"):
                return
            return await LLM._invoke_async(self, **kwargs)

        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()

        output_schema = self._get_output_schema()
        schema_prompt = ""
        if output_schema:
            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
            schema_prompt = structured_output_prompt(schema)

        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
        ex = self.exception_handler()
        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
            return

        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        use_tools = []
        ans = ""
        async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
            if self.check_if_canceled("Agent processing"):
                return
            ans += delta_ans

        if ans.find("**ERROR**") >= 0:
            logging.error(f"Agent._chat got error. response: {ans}")
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", ans)
            return

        if output_schema:
            error = ""
            for _ in range(self._param.max_retries + 1):
                try:
                    def clean_formated_answer(ans: str) -> str:
                        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
                        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
                        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

                    obj = json_repair.loads(clean_formated_answer(ans))
                    self.set_output("structured", obj)
                    if use_tools:
                        self.set_output("use_tools", use_tools)
                    return obj
                except Exception:
                    error = "The answer cannot be parsed as JSON"
                    ans = await self._force_format_to_schema_async(ans, schema_prompt)
                    if ans.find("**ERROR**") >= 0:
                        continue
            self.set_output("_ERROR", error)
            return

        self.set_output("content", ans)
        if use_tools:
            self.set_output("use_tools", use_tools)
        return ans
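
    # Streaming variant used when the answer is piped straight to a Message component:
    # deltas are yielded as they arrive and the accumulated text is stored on the
    # "content" output at the end.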
    async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer_without_toolcall = ""
        use_tools = []
        async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
            if self.check_if_canceled("Agent streaming"):
                return
            if delta_ans.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", delta_ans)
                return
            answer_without_toolcall += delta_ans
            yield delta_ans

        self.set_output("content", answer_without_toolcall)
        if use_tools:
            self.set_output("use_tools", use_tools)
    async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
        token_count = 0
        tool_metas = self.tool_meta
        hist = deepcopy(history)
        last_calling = ""
        if len(hist) > 3:
            st = timer()
            user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
            self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer() - st)
        else:
            user_request = history[-1]["content"]

        async def use_tool_async(name, args):
            nonlocal hist, use_tools, last_calling
            logging.info(f"{last_calling=} == {name=}")
            last_calling = name
            tool_response = await self.toolcall_session.tool_call_async(name, args)
            use_tools.append({
                "name": name,
                "arguments": args,
                "results": tool_response
            })
            # self.callback("add_memory", {}, "...")
            # self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
            return name, tool_response

        async def complete():
            nonlocal hist
            need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
            if schema_prompt:
                need2cite = False
            cited = False
            if hist and hist[0]["role"] == "system":
                if schema_prompt:
                    hist[0]["content"] += "\n" + schema_prompt
                if need2cite and len(hist) < 7:
                    hist[0]["content"] += citation_prompt()
                    cited = True
            yield "", token_count

            _hist = hist
            if len(hist) > 12:
                _hist = [hist[0], hist[1], *hist[-10:]]

            entire_txt = ""
            async for delta_ans in self._generate_streamly_async(_hist):
                if not need2cite or cited:
                    yield delta_ans, 0
                entire_txt += delta_ans
            if not need2cite or cited:
                return

            st = timer()
            txt = ""
            async for delta_ans in self._gen_citations_async(entire_txt):
                if self.check_if_canceled("Agent streaming"):
                    return
                yield delta_ans, 0
                txt += delta_ans
            self.callback("gen_citations", {}, txt, elapsed_time=timer() - st)

        def append_user_content(hist, content):
            if hist[-1]["role"] == "user":
                hist[-1]["content"] += content
            else:
                hist.append({"role": "user", "content": content})

        st = timer()
        task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
        self.callback("analyze_task", {}, task_desc, elapsed_time=timer() - st)

        for _ in range(self._param.max_rounds + 1):
            if self.check_if_canceled("Agent streaming"):
                return

            response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
            # self.callback("next_step", {}, str(response)[:256]+"...")
            token_count += tk or 0
            hist.append({"role": "assistant", "content": response})
            try:
                functions = json_repair.loads(re.sub(r"```.*", "", response))
                if not isinstance(functions, list):
                    raise TypeError(f"List should be returned, but `{functions}`")
                for f in functions:
                    if not isinstance(f, dict):
                        raise TypeError(f"An object type should be returned, but `{f}`")

                tool_tasks = []
                for func in functions:
                    name = func["name"]
                    args = func["arguments"]
                    if name == COMPLETE_TASK:
                        append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
                        async for txt, tkcnt in complete():
                            yield txt, tkcnt
                        return
                    tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))

                results = await asyncio.gather(*tool_tasks) if tool_tasks else []

                st = timer()
                reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
                append_user_content(hist, reflection)
                self.callback("reflection", {}, str(reflection), elapsed_time=timer() - st)
            except Exception as e:
                logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
                e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
                append_user_content(hist, str(e))

        logging.warning(f"Exceeded max rounds: {self._param.max_rounds}")
        final_instruction = f"""
{user_request}

IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.

Instructions:
1. SYNTHESIZE all information collected during this conversation
2. Provide a COMPLETE response using existing data - do not suggest additional research
3. Structure your response as a FINAL DELIVERABLE, not a plan
4. If information is incomplete, state what you found and provide the best analysis possible with available data
5. DO NOT mention conversation limits or suggest further steps
6. Focus on delivering VALUE with the information already gathered

Respond immediately with your final comprehensive answer.
"""
        if self.check_if_canceled("Agent final instruction"):
            return

        append_user_content(hist, final_instruction)
        async for txt, tkcnt in complete():
            yield txt, tkcnt
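
    # Re-generate the answer with citation markup, feeding the retrieved chunks back
    # to the model as reference material.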
    async def _gen_citations_async(self, text):
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
                                                              {"role": "user", "content": text}
                                                              ]):
            yield delta_ans

    def reset(self, only_output=False):
        """
        Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
        """
        for k in self._param.outputs.keys():
            self._param.outputs[k]["value"] = None

        for k, cpn in self.tools.items():
            if hasattr(cpn, "reset") and callable(cpn.reset):
                cpn.reset()

        if only_output:
            return

        for k in self._param.inputs.keys():
            self._param.inputs[k]["value"] = None
        self._param.debug_inputs = {}