Refa: cleanup synchronous functions in agent_with_tools (#11736)

### What problem does this PR solve?

Cleanup synchronous functions in agent_with_tools.

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-12-04 14:15:05 +08:00
committed by GitHub
parent 797e03f843
commit 27b0550876
4 changed files with 105 additions and 256 deletions

View File

@ -18,7 +18,6 @@ import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any
@ -30,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@ -154,96 +153,19 @@ class Agent(LLM, ToolBase):
return None
def _force_format_to_schema(self, text: str, schema_prompt: str) -> str:
async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
fmt_msgs = [
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
{"role": "user", "content": text},
]
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
return self._generate(fmt_msgs)
return await self._generate_async(fmt_msgs)
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Agent processing"):
return
if kwargs.get("user_prompt"):
usr_pmt = ""
if kwargs.get("reasoning"):
usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
if kwargs.get("context"):
usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
if usr_pmt:
usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
else:
usr_pmt = str(kwargs["user_prompt"])
self._param.prompts = [{"role": "user", "content": usr_pmt}]
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return LLM._invoke(self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
output_schema = self._get_output_schema()
schema_prompt = ""
if output_schema:
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
schema_prompt = structured_output_prompt(schema)
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
if ans.find("**ERROR**") >= 0:
logging.error(f"Agent._chat got error. response: {ans}")
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
else:
self.set_output("_ERROR", ans)
return
if output_schema:
error = ""
for _ in range(self._param.max_retries + 1):
try:
def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
obj = json_repair.loads(clean_formated_answer(ans))
self.set_output("structured", obj)
if use_tools:
self.set_output("use_tools", use_tools)
return obj
except Exception:
error = "The answer cannot be parsed as JSON"
ans = self._force_format_to_schema(ans, schema_prompt)
if ans.find("**ERROR**") >= 0:
continue
self.set_output("_ERROR", error)
return
self.set_output("content", ans)
if use_tools:
self.set_output("use_tools", use_tools)
return ans
async def _invoke_async(self, **kwargs):
"""
Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
"""
if self.check_if_canceled("Agent processing"):
return
@ -262,7 +184,7 @@ class Agent(LLM, ToolBase):
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return await asyncio.to_thread(LLM._invoke, self, **kwargs)
return await LLM._invoke_async(self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
output_schema = self._get_output_schema()
@ -274,13 +196,13 @@ class Agent(LLM, ToolBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
@ -308,7 +230,7 @@ class Agent(LLM, ToolBase):
return obj
except Exception:
error = "The answer cannot be parsed as JSON"
ans = self._force_format_to_schema(ans, schema_prompt)
ans = await self._force_format_to_schema_async(ans, schema_prompt)
if ans.find("**ERROR**") >= 0:
continue
@ -320,28 +242,6 @@ class Agent(LLM, ToolBase):
self.set_output("use_tools", use_tools)
return ans
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", delta_ans)
return
answer_without_toolcall += delta_ans
yield delta_ans
self.set_output("content", answer_without_toolcall)
if use_tools:
self.set_output("use_tools", use_tools)
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
@ -365,64 +265,22 @@ class Agent(LLM, ToolBase):
self.set_output("use_tools", use_tools)
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
"""
Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
"""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
except Exception as e:
asyncio.run_coroutine_threadsafe(queue.put(e), loop)
finally:
asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
await asyncio.to_thread(worker)
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
def _gen_citations(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
last_calling = ""
if len(hist) > 3:
st = timer()
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
else:
user_request = history[-1]["content"]
def use_tool(name, args):
nonlocal hist, use_tools, token_count,last_calling,user_request
async def use_tool_async(name, args):
nonlocal hist, use_tools, last_calling
logging.info(f"{last_calling=} == {name=}")
# Summarize of function calling
#if all([
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
# last_calling,
# last_calling != name
#]):
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
last_calling = name
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = await self.toolcall_session.tool_call_async(name, args)
use_tools.append({
"name": name,
"arguments": args,
@ -433,7 +291,7 @@ class Agent(LLM, ToolBase):
return name, tool_response
def complete():
async def complete():
nonlocal hist
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
if schema_prompt:
@ -451,7 +309,7 @@ class Agent(LLM, ToolBase):
if len(hist) > 12:
_hist = [hist[0], hist[1], *hist[-10:]]
entire_txt = ""
for delta_ans in self._generate_streamly(_hist):
async for delta_ans in self._generate_streamly_async(_hist):
if not need2cite or cited:
yield delta_ans, 0
entire_txt += delta_ans
@ -460,7 +318,7 @@ class Agent(LLM, ToolBase):
st = timer()
txt = ""
for delta_ans in self._gen_citations(entire_txt):
async for delta_ans in self._gen_citations_async(entire_txt):
if self.check_if_canceled("Agent streaming"):
return
yield delta_ans, 0
@ -475,14 +333,14 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})
st = timer()
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):
return
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# self.callback("next_step", {}, str(response)[:256]+"...")
token_count += tk
token_count += tk or 0
hist.append({"role": "assistant", "content": response})
try:
functions = json_repair.loads(re.sub(r"```.*", "", response))
@ -491,23 +349,24 @@ class Agent(LLM, ToolBase):
for f in functions:
if not isinstance(f, dict):
raise TypeError(f"An object type should be returned, but `{f}`")
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
for txt, tkcnt in complete():
yield txt, tkcnt
return
thr.append(executor.submit(use_tool, name, args))
tool_tasks = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
async for txt, tkcnt in complete():
yield txt, tkcnt
return
st = timer()
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
st = timer()
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
@ -531,21 +390,17 @@ Respond immediately with your final comprehensive answer.
return
append_user_content(hist, final_instruction)
for txt, tkcnt in complete():
async for txt, tkcnt in complete():
yield txt, tkcnt
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
# self.callback("get_useful_memory", {"topn": 3}, "...")
mems = self._canvas.get_memory()
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
try:
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
mems = [mems[r] for r in rank]
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
except Exception as e:
logging.exception(e)
return "Error occurred."
async def _gen_citations_async(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
def reset(self, only_output=False):
"""

View File

@ -327,7 +327,7 @@ class LLM(ComponentBase):
self.set_output("content", answer)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("LLM processing"):
return
@ -338,22 +338,25 @@ class LLM(ComponentBase):
prompt, msg, _ = self._prepare_prompt_variables()
error: str = ""
output_structure=None
output_structure = None
try:
output_structure = self._param.outputs['structured']
output_structure = self._param.outputs["structured"]
except Exception:
pass
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt += structured_output_prompt(schema)
for _ in range(self._param.max_retries+1):
schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt_with_schema = prompt + structured_output_prompt(schema)
for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("LLM processing"):
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
_, msg_fit = message_fit_in(
[{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
int(self.chat_mdl.max_length * 0.97),
)
error = ""
ans = self._generate(msg)
msg.pop(0)
ans = await self._generate_async(msg_fit)
msg_fit.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
@ -362,7 +365,7 @@ class LLM(ComponentBase):
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
return
except Exception:
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
error = "The answer can't not be parsed as JSON"
if error:
self.set_output("_ERROR", error)
@ -370,18 +373,23 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output_async, prompt, msg))
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
ex and ex["goto"]
):
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
return
for _ in range(self._param.max_retries+1):
error = ""
for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("LLM processing"):
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
_, msg_fit = message_fit_in(
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
)
error = ""
ans = self._generate(msg)
msg.pop(0)
ans = await self._generate_async(msg_fit)
msg_fit.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
@ -395,23 +403,9 @@ class LLM(ComponentBase):
else:
self.set_output("_ERROR", error)
def _stream_output(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
if self.check_if_canceled("LLM streaming"):
return
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield ans
answer += ans
self.set_output("content", answer)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)

View File

@ -49,16 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
self.callback = callback
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
return asyncio.run(self.tool_call_async(name, arguments))
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist"
st = timer()
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
resp = tool_obj.tool_call(name, arguments, 60)
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
else:
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = asyncio.run(tool_obj.invoke_async(**arguments))
resp = await tool_obj.invoke_async(**arguments)
else:
resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp

View File

@ -343,7 +343,8 @@ def form_history(history, limit=-6):
return context
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
tools_desc = tool_schema(tools_description)
context = ""
@ -352,7 +353,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
else:
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -361,13 +362,17 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
return kwd
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
chat_async = getattr(chat_mdl, "async_chat", None)
if chat_async and asyncio.iscoroutinefunction(chat_async):
return await chat_async(system, history, **kwargs)
return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description:
return ""
return "", 0
desc = tool_schema(tools_description)
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
@ -376,18 +381,18 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:], stop=["<|stop|>"])
json_str = await _chat_async(
chat_mdl,
template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:],
stop=["<|stop|>"],
)
tk_cnt = num_tokens_from_string(json_str)
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
return json_str, tk_cnt
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@ -398,7 +403,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
**Observation**
@ -429,23 +434,15 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
meta_data_structure = {}
for key, values in meta_data.items():
@ -514,7 +511,7 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The response should be in the following JSON format:
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
@ -641,8 +638,8 @@ def toc_transformer(toc_pages, chat_mdl):
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The `title` is a short phrase or a several-words term.
The response should be in the following JSON format:
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
@ -667,7 +664,7 @@ def toc_transformer(toc_pages, chat_mdl):
while not (if_complete == "yes"):
prompt = f"""
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
The response should be in the following JSON format:
The response should be in the following JSON format:
The raw table of contents json structure is:
{toc_content}
@ -756,7 +753,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
for chunk in chunks_res:
titles.extend(chunk.get("toc", []))
# Filter out entries with title == -1
prune = len(titles) > 512
max_len = 12 if prune else 22