Refa: cleanup synchronous functions in agent_with_tools (#11736)

### What problem does this PR solve?

Clean up the synchronous functions in agent_with_tools: the duplicated synchronous code paths are removed and the async implementations become the single code path.
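
The remaining synchronous entry points simply drive the async versions. A minimal sketch of the resulting pattern (illustrative names, not the actual classes in this repo):

```python
import asyncio
import time


class AgentSketch:
    """Illustrative stand-in for the refactored component; the names here are
    hypothetical and only demonstrate the sync-to-async delegation pattern."""

    def invoke(self, query: str) -> str:
        # Sync entry point kept for callers that are not async-aware:
        # it just drives the async implementation to completion.
        return asyncio.run(self.invoke_async(query))

    async def invoke_async(self, query: str) -> str:
        # Blocking work (e.g. a synchronous chat-model client) is offloaded
        # to a worker thread so the event loop is not blocked.
        plan = await asyncio.to_thread(self._blocking_chat, f"plan for: {query}")

        # Tool calls run concurrently via asyncio.gather instead of a
        # ThreadPoolExecutor fan-out.
        results = await asyncio.gather(
            self._call_tool("search", query),
            self._call_tool("calculator", query),
        )
        return f"{plan} | " + ", ".join(results)

    async def _call_tool(self, name: str, query: str) -> str:
        await asyncio.sleep(0.1)  # stands in for real async tool I/O
        return f"{name}({query})"

    def _blocking_chat(self, prompt: str) -> str:
        time.sleep(0.1)  # stands in for a blocking LLM call
        return prompt.upper()


if __name__ == "__main__":
    print(AgentSketch().invoke("weather in Berlin"))
```

Blocking helpers with no async variant are pushed to worker threads with `asyncio.to_thread`, tools that expose a native coroutine are awaited directly, and parallel tool calls use `asyncio.gather` instead of a `ThreadPoolExecutor`.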

### Type of change

- [x] Refactoring
Yongteng Lei
2025-12-04 14:15:05 +08:00
committed by GitHub
parent 797e03f843
commit 27b0550876
4 changed files with 105 additions and 256 deletions

View File

@@ -18,7 +18,6 @@ import json
 import logging
 import os
 import re
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -30,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from common.connection_utils import timeout
-from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
-    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
+from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
+    citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
 from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM
@@ -154,96 +153,19 @@ class Agent(LLM, ToolBase):
         return None

-    def _force_format_to_schema(self, text: str, schema_prompt: str) -> str:
+    async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
         fmt_msgs = [
             {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
             {"role": "user", "content": text},
         ]
         _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
-        return self._generate(fmt_msgs)
+        return await self._generate_async(fmt_msgs)
+
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))

     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
-    def _invoke(self, **kwargs):
-        if self.check_if_canceled("Agent processing"):
-            return
-        if kwargs.get("user_prompt"):
-            usr_pmt = ""
-            if kwargs.get("reasoning"):
-                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
-            if kwargs.get("context"):
-                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
-            if usr_pmt:
-                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
-            else:
-                usr_pmt = str(kwargs["user_prompt"])
-            self._param.prompts = [{"role": "user", "content": usr_pmt}]
-        if not self.tools:
-            if self.check_if_canceled("Agent processing"):
-                return
-            return LLM._invoke(self, **kwargs)
-        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
-        output_schema = self._get_output_schema()
-        schema_prompt = ""
-        if output_schema:
-            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
-            schema_prompt = structured_output_prompt(schema)
-        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
-        ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
-            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
-            return
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        use_tools = []
-        ans = ""
-        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
-            if self.check_if_canceled("Agent processing"):
-                return
-            ans += delta_ans
-        if ans.find("**ERROR**") >= 0:
-            logging.error(f"Agent._chat got error. response: {ans}")
-            if self.get_exception_default_value():
-                self.set_output("content", self.get_exception_default_value())
-            else:
-                self.set_output("_ERROR", ans)
-            return
-        if output_schema:
-            error = ""
-            for _ in range(self._param.max_retries + 1):
-                try:
-                    def clean_formated_answer(ans: str) -> str:
-                        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
-                        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
-                        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
-                    obj = json_repair.loads(clean_formated_answer(ans))
-                    self.set_output("structured", obj)
-                    if use_tools:
-                        self.set_output("use_tools", use_tools)
-                    return obj
-                except Exception:
-                    error = "The answer cannot be parsed as JSON"
-                    ans = self._force_format_to_schema(ans, schema_prompt)
-                    if ans.find("**ERROR**") >= 0:
-                        continue
-            self.set_output("_ERROR", error)
-            return
-        self.set_output("content", ans)
-        if use_tools:
-            self.set_output("use_tools", use_tools)
-        return ans
     async def _invoke_async(self, **kwargs):
-        """
-        Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
-        """
         if self.check_if_canceled("Agent processing"):
             return
@@ -262,7 +184,7 @@ class Agent(LLM, ToolBase):
         if not self.tools:
             if self.check_if_canceled("Agent processing"):
                 return
-            return await asyncio.to_thread(LLM._invoke, self, **kwargs)
+            return await LLM._invoke_async(self, **kwargs)
         prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
         output_schema = self._get_output_schema()
@@ -274,13 +196,13 @@ class Agent(LLM, ToolBase):
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
         if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
-            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
+            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
             return
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+        async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
             if self.check_if_canceled("Agent processing"):
                 return
             ans += delta_ans
@@ -308,7 +230,7 @@ class Agent(LLM, ToolBase):
                     return obj
                 except Exception:
                     error = "The answer cannot be parsed as JSON"
-                    ans = self._force_format_to_schema(ans, schema_prompt)
+                    ans = await self._force_format_to_schema_async(ans, schema_prompt)
                     if ans.find("**ERROR**") >= 0:
                         continue
@@ -320,28 +242,6 @@ class Agent(LLM, ToolBase):
             self.set_output("use_tools", use_tools)
         return ans

-    def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer_without_toolcall = ""
-        use_tools = []
-        for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
-            if self.check_if_canceled("Agent streaming"):
-                return
-            if delta_ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", delta_ans)
-                return
-            answer_without_toolcall += delta_ans
-            yield delta_ans
-        self.set_output("content", answer_without_toolcall)
-        if use_tools:
-            self.set_output("use_tools", use_tools)
-
     async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
@@ -365,64 +265,22 @@ class Agent(LLM, ToolBase):
             self.set_output("use_tools", use_tools)

     async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
-        """
-        Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
-        """
-        loop = asyncio.get_running_loop()
-        queue: asyncio.Queue = asyncio.Queue()
-        def worker():
-            try:
-                for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
-                    asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
-            except Exception as e:
-                asyncio.run_coroutine_threadsafe(queue.put(e), loop)
-            finally:
-                asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
-        await asyncio.to_thread(worker)
-        while True:
-            item = await queue.get()
-            if item is StopAsyncIteration:
-                break
-            if isinstance(item, Exception):
-                raise item
-            yield item
-
-    def _gen_citations(self, text):
-        retrievals = self._canvas.get_reference()
-        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
-        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
-        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
-                                                  {"role": "user", "content": text}
-                                                  ]):
-            yield delta_ans
-
-    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
         token_count = 0
         tool_metas = self.tool_meta
         hist = deepcopy(history)
         last_calling = ""
         if len(hist) > 3:
             st = timer()
-            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
+            user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
             self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
         else:
             user_request = history[-1]["content"]

-        def use_tool(name, args):
-            nonlocal hist, use_tools, token_count,last_calling,user_request
+        async def use_tool_async(name, args):
+            nonlocal hist, use_tools, last_calling
             logging.info(f"{last_calling=} == {name=}")
-            # Summarize of function calling
-            #if all([
-            #    isinstance(self.toolcall_session.get_tool_obj(name), Agent),
-            #    last_calling,
-            #    last_calling != name
-            #]):
-            #    self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
             last_calling = name
-            tool_response = self.toolcall_session.tool_call(name, args)
+            tool_response = await self.toolcall_session.tool_call_async(name, args)
             use_tools.append({
                 "name": name,
                 "arguments": args,
@@ -433,7 +291,7 @@ class Agent(LLM, ToolBase):
             return name, tool_response

-        def complete():
+        async def complete():
             nonlocal hist
             need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
             if schema_prompt:
@@ -451,7 +309,7 @@ class Agent(LLM, ToolBase):
             if len(hist) > 12:
                 _hist = [hist[0], hist[1], *hist[-10:]]
             entire_txt = ""
-            for delta_ans in self._generate_streamly(_hist):
+            async for delta_ans in self._generate_streamly_async(_hist):
                 if not need2cite or cited:
                     yield delta_ans, 0
                 entire_txt += delta_ans
@@ -460,7 +318,7 @@ class Agent(LLM, ToolBase):
             st = timer()
             txt = ""
-            for delta_ans in self._gen_citations(entire_txt):
+            async for delta_ans in self._gen_citations_async(entire_txt):
                 if self.check_if_canceled("Agent streaming"):
                     return
                 yield delta_ans, 0
@@ -475,14 +333,14 @@ class Agent(LLM, ToolBase):
             hist.append({"role": "user", "content": content})

         st = timer()
-        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
+        task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
         self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
         for _ in range(self._param.max_rounds + 1):
             if self.check_if_canceled("Agent streaming"):
                 return
-            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
+            response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
             # self.callback("next_step", {}, str(response)[:256]+"...")
-            token_count += tk
+            token_count += tk or 0
             hist.append({"role": "assistant", "content": response})
             try:
                 functions = json_repair.loads(re.sub(r"```.*", "", response))
@@ -491,21 +349,22 @@ class Agent(LLM, ToolBase):
                 for f in functions:
                     if not isinstance(f, dict):
                         raise TypeError(f"An object type should be returned, but `{f}`")
-                with ThreadPoolExecutor(max_workers=5) as executor:
-                    thr = []
+                tool_tasks = []
                 for func in functions:
                     name = func["name"]
                     args = func["arguments"]
                     if name == COMPLETE_TASK:
                         append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
-                        for txt, tkcnt in complete():
+                        async for txt, tkcnt in complete():
                             yield txt, tkcnt
                         return
-                    thr.append(executor.submit(use_tool, name, args))
+                    tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
+                results = await asyncio.gather(*tool_tasks) if tool_tasks else []
                 st = timer()
-                reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
+                reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
                 append_user_content(hist, reflection)
                 self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
@@ -531,21 +390,17 @@ Respond immediately with your final comprehensive answer.
                 return
         append_user_content(hist, final_instruction)
-        for txt, tkcnt in complete():
+        async for txt, tkcnt in complete():
             yield txt, tkcnt

-    def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
-        # self.callback("get_useful_memory", {"topn": 3}, "...")
-        mems = self._canvas.get_memory()
-        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
-        try:
-            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
-            mems = [mems[r] for r in rank]
-            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
-        except Exception as e:
-            logging.exception(e)
-            return "Error occurred."
+    async def _gen_citations_async(self, text):
+        retrievals = self._canvas.get_reference()
+        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
+        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
+        async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
+                                                              {"role": "user", "content": text}
+                                                              ]):
+            yield delta_ans

     def reset(self, only_output=False):
         """

View File

@@ -327,7 +327,7 @@ class LLM(ComponentBase):
         self.set_output("content", answer)

     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("LLM processing"):
             return
@@ -338,22 +338,25 @@ class LLM(ComponentBase):
         prompt, msg, _ = self._prepare_prompt_variables()
         error: str = ""
-        output_structure=None
+        output_structure = None
         try:
-            output_structure = self._param.outputs['structured']
+            output_structure = self._param.outputs["structured"]
         except Exception:
             pass
         if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
-            schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
-            prompt += structured_output_prompt(schema)
-            for _ in range(self._param.max_retries+1):
+            schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
+            prompt_with_schema = prompt + structured_output_prompt(schema)
+            for _ in range(self._param.max_retries + 1):
                 if self.check_if_canceled("LLM processing"):
                     return
-                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+                _, msg_fit = message_fit_in(
+                    [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
+                    int(self.chat_mdl.max_length * 0.97),
+                )
                 error = ""
-                ans = self._generate(msg)
-                msg.pop(0)
+                ans = await self._generate_async(msg_fit)
+                msg_fit.pop(0)
                 if ans.find("**ERROR**") >= 0:
                     logging.error(f"LLM response error: {ans}")
                     error = ans
@@ -362,7 +365,7 @@ class LLM(ComponentBase):
                         self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                         return
                     except Exception:
-                        msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
+                        msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
                         error = "The answer can't not be parsed as JSON"
             if error:
                 self.set_output("_ERROR", error)
@@ -370,18 +373,23 @@ class LLM(ComponentBase):
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
-            self.set_output("content", partial(self._stream_output_async, prompt, msg))
+        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
+            ex and ex["goto"]
+        ):
+            self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
             return
-        for _ in range(self._param.max_retries+1):
+        error = ""
+        for _ in range(self._param.max_retries + 1):
             if self.check_if_canceled("LLM processing"):
                 return
-            _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+            _, msg_fit = message_fit_in(
+                [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
+            )
             error = ""
-            ans = self._generate(msg)
-            msg.pop(0)
+            ans = await self._generate_async(msg_fit)
+            msg_fit.pop(0)
             if ans.find("**ERROR**") >= 0:
                 logging.error(f"LLM response error: {ans}")
                 error = ans
@@ -395,23 +403,9 @@ class LLM(ComponentBase):
             else:
                 self.set_output("_ERROR", error)

-    def _stream_output(self, prompt, msg):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer = ""
-        for ans in self._generate_streamly(msg):
-            if self.check_if_canceled("LLM streaming"):
-                return
-            if ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", ans)
-                return
-            yield ans
-            answer += ans
-        self.set_output("content", answer)
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))

     def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
         summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)

View File

@@ -49,16 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
         self.callback = callback

     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
+        return asyncio.run(self.tool_call_async(name, arguments))
+
+    async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
         st = timer()
         tool_obj = self.tools_map[name]
         if isinstance(tool_obj, MCPToolCallSession):
-            resp = tool_obj.tool_call(name, arguments, 60)
+            resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
         else:
             if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
-                resp = asyncio.run(tool_obj.invoke_async(**arguments))
+                resp = await tool_obj.invoke_async(**arguments)
             else:
-                resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
+                resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp

View File

@@ -343,7 +343,8 @@ def form_history(history, limit=-6):
     return context

-def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
     tools_desc = tool_schema(tools_description)
     context = ""
@@ -352,7 +353,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     else:
         template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
         context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
-    kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
+    kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
     if isinstance(kwd, tuple):
         kwd = kwd[0]
     kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@@ -361,13 +362,17 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     return kwd

-async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
+async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
+    chat_async = getattr(chat_mdl, "async_chat", None)
+    if chat_async and asyncio.iscoroutinefunction(chat_async):
+        return await chat_async(system, history, **kwargs)
+    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)

-def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
-        return ""
+        return "", 0
     desc = tool_schema(tools_description)
     template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
     user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
@@ -376,18 +381,18 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
         hist[-1]["content"] += user_prompt
     else:
         hist.append({"role": "user", "content": user_prompt})
-    json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
-                             hist[1:], stop=["<|stop|>"])
+    json_str = await _chat_async(
+        chat_mdl,
+        template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
+        hist[1:],
+        stop=["<|stop|>"],
+    )
     tk_cnt = num_tokens_from_string(json_str)
     json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
     return json_str, tk_cnt

-async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
-    return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
-
-def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
     template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@@ -398,7 +403,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
     else:
         hist.append({"role": "user", "content": user_prompt})
     _, msg = message_fit_in(hist, chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
+    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
     ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
     return """
 **Observation**
@@ -429,23 +434,15 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

-def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
     template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
     system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
     user_prompt = " → rank: "
     _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
+    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

-async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
-
-async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
-
 def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():