Refa: cleanup synchronous functions in agent_with_tools (#11736)
### What problem does this PR solve?

Clean up synchronous functions in agent_with_tools.

### Type of change

- [x] Refactoring
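The change follows one recurring shape: a component's heavy logic moves into an `async` coroutine, the old synchronous entry point survives only as a thin compatibility wrapper that drives the coroutine with `asyncio.run`, and the remaining blocking helpers are pushed off the event loop with `asyncio.to_thread`. A minimal, runnable sketch of that shape (the class and method names below are illustrative, not the project's actual classes):

```python
import asyncio


class ComponentSketch:
    """Illustrative component: the async path owns the logic, sync is a shim."""

    def _invoke(self, **kwargs):
        # Thin compatibility wrapper: spin up an event loop and drive the coroutine.
        return asyncio.run(self._invoke_async(**kwargs))

    async def _invoke_async(self, **kwargs):
        # All real work lives here; blocking helpers are pushed to a worker thread.
        data = await asyncio.to_thread(self._blocking_fetch, kwargs.get("query", ""))
        return f"processed: {data}"

    def _blocking_fetch(self, query: str) -> str:
        # Stand-in for a synchronous dependency (DB call, non-async HTTP client, ...).
        return query.upper()


if __name__ == "__main__":
    print(ComponentSketch()._invoke(query="hello"))
```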
@@ -18,7 +18,6 @@ import json
 import logging
 import os
 import re
-from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -30,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from common.connection_utils import timeout
-from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
-    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
+from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
+    citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
 from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM
 
@@ -154,96 +153,19 @@ class Agent(LLM, ToolBase):
         return None
 
-    def _force_format_to_schema(self, text: str, schema_prompt: str) -> str:
+    async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
         fmt_msgs = [
             {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
             {"role": "user", "content": text},
         ]
         _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
-        return self._generate(fmt_msgs)
+        return await self._generate_async(fmt_msgs)
 
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
+
-    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
-    def _invoke(self, **kwargs):
-        if self.check_if_canceled("Agent processing"):
-            return
-
-        if kwargs.get("user_prompt"):
-            usr_pmt = ""
-            if kwargs.get("reasoning"):
-                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
-            if kwargs.get("context"):
-                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
-            if usr_pmt:
-                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
-            else:
-                usr_pmt = str(kwargs["user_prompt"])
-            self._param.prompts = [{"role": "user", "content": usr_pmt}]
-
-        if not self.tools:
-            if self.check_if_canceled("Agent processing"):
-                return
-            return LLM._invoke(self, **kwargs)
-
-        prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
-        output_schema = self._get_output_schema()
-        schema_prompt = ""
-        if output_schema:
-            schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
-            schema_prompt = structured_output_prompt(schema)
-
-        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
-        ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
-            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
-            return
-
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        use_tools = []
-        ans = ""
-        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
-            if self.check_if_canceled("Agent processing"):
-                return
-            ans += delta_ans
-
-        if ans.find("**ERROR**") >= 0:
-            logging.error(f"Agent._chat got error. response: {ans}")
-            if self.get_exception_default_value():
-                self.set_output("content", self.get_exception_default_value())
-            else:
-                self.set_output("_ERROR", ans)
-            return
-
-        if output_schema:
-            error = ""
-            for _ in range(self._param.max_retries + 1):
-                try:
-                    def clean_formated_answer(ans: str) -> str:
-                        ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
-                        ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
-                        return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
-                    obj = json_repair.loads(clean_formated_answer(ans))
-                    self.set_output("structured", obj)
-                    if use_tools:
-                        self.set_output("use_tools", use_tools)
-                    return obj
-                except Exception:
-                    error = "The answer cannot be parsed as JSON"
-                    ans = self._force_format_to_schema(ans, schema_prompt)
-                    if ans.find("**ERROR**") >= 0:
-                        continue
-
-            self.set_output("_ERROR", error)
-            return
-
-        self.set_output("content", ans)
-        if use_tools:
-            self.set_output("use_tools", use_tools)
-        return ans
-
     async def _invoke_async(self, **kwargs):
         """
         Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
         """
         if self.check_if_canceled("Agent processing"):
             return
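In the hunk above, the structured-output path keeps the same retry idea but awaits `_force_format_to_schema_async` between attempts: parse the answer as JSON, and if that fails, ask the model once more to emit schema-conformant output. A small self-contained sketch of that loop, using the standard `json.loads` rather than the project's `json_repair`, with a stand-in `reformat` coroutine in place of the real formatting call:

```python
import asyncio
import json
import re


async def parse_with_retries(ans: str, reformat, max_retries: int = 2):
    """Parse an LLM answer as JSON; on failure, ask the model to reformat and retry."""

    def clean(text: str) -> str:
        # Strip reasoning traces and markdown fences, as the hunk above does.
        text = re.sub(r"^.*</think>", "", text, flags=re.DOTALL)
        text = re.sub(r"^.*```json", "", text, flags=re.DOTALL)
        return re.sub(r"```\n*$", "", text, flags=re.DOTALL)

    for _ in range(max_retries + 1):
        try:
            return json.loads(clean(ans))
        except json.JSONDecodeError:
            ans = await reformat(ans)  # ask the model for valid JSON again
    raise ValueError("The answer cannot be parsed as JSON")


async def _demo():
    async def fake_reformat(_: str) -> str:
        return '{"ok": true}'

    print(await parse_with_retries("```json\n{broken", fake_reformat))


if __name__ == "__main__":
    asyncio.run(_demo())
```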
@@ -262,7 +184,7 @@ class Agent(LLM, ToolBase):
         if not self.tools:
             if self.check_if_canceled("Agent processing"):
                 return
-            return await asyncio.to_thread(LLM._invoke, self, **kwargs)
+            return await LLM._invoke_async(self, **kwargs)
 
         prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
         output_schema = self._get_output_schema()
@@ -274,13 +196,13 @@ class Agent(LLM, ToolBase):
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
         if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
-            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
+            self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
             return
 
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
+        async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
             if self.check_if_canceled("Agent processing"):
                 return
             ans += delta_ans
@@ -308,7 +230,7 @@ class Agent(LLM, ToolBase):
                     return obj
                 except Exception:
                     error = "The answer cannot be parsed as JSON"
-                    ans = self._force_format_to_schema(ans, schema_prompt)
+                    ans = await self._force_format_to_schema_async(ans, schema_prompt)
                     if ans.find("**ERROR**") >= 0:
                         continue
 
@@ -320,28 +242,6 @@ class Agent(LLM, ToolBase):
             self.set_output("use_tools", use_tools)
         return ans
 
-    def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer_without_toolcall = ""
-        use_tools = []
-        for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
-            if self.check_if_canceled("Agent streaming"):
-                return
-
-            if delta_ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", delta_ans)
-                return
-            answer_without_toolcall += delta_ans
-            yield delta_ans
-
-        self.set_output("content", answer_without_toolcall)
-        if use_tools:
-            self.set_output("use_tools", use_tools)
-
     async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
@@ -365,64 +265,22 @@ class Agent(LLM, ToolBase):
             self.set_output("use_tools", use_tools)
 
     async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
-        """
-        Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
-        """
-        loop = asyncio.get_running_loop()
-        queue: asyncio.Queue = asyncio.Queue()
-
-        def worker():
-            try:
-                for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
-                    asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
-            except Exception as e:
-                asyncio.run_coroutine_threadsafe(queue.put(e), loop)
-            finally:
-                asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
-
-        await asyncio.to_thread(worker)
-
-        while True:
-            item = await queue.get()
-            if item is StopAsyncIteration:
-                break
-            if isinstance(item, Exception):
-                raise item
-            yield item
-
-    def _gen_citations(self, text):
-        retrievals = self._canvas.get_reference()
-        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
-        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
-        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
-                                                  {"role": "user", "content": text}
-                                                  ]):
-            yield delta_ans
-
-    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
         token_count = 0
         tool_metas = self.tool_meta
         hist = deepcopy(history)
         last_calling = ""
         if len(hist) > 3:
             st = timer()
-            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
+            user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
             self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
         else:
             user_request = history[-1]["content"]
 
-        def use_tool(name, args):
-            nonlocal hist, use_tools, token_count,last_calling,user_request
+        async def use_tool_async(name, args):
+            nonlocal hist, use_tools, last_calling
             logging.info(f"{last_calling=} == {name=}")
-            # Summarize of function calling
-            #if all([
-            #    isinstance(self.toolcall_session.get_tool_obj(name), Agent),
-            #    last_calling,
-            #    last_calling != name
-            #]):
-            #    self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
             last_calling = name
-            tool_response = self.toolcall_session.tool_call(name, args)
+            tool_response = await self.toolcall_session.tool_call_async(name, args)
             use_tools.append({
                 "name": name,
                 "arguments": args,
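The removed `_react_with_tools_streamly_async` bridged the synchronous generator into asyncio by running it in a worker thread and shipping items back through a queue; as written it awaited `asyncio.to_thread(worker)` before draining the queue, so deltas were buffered until the synchronous generator finished rather than streamed. The replacement makes the ReAct loop a native async generator. A minimal sketch of that shape (the step list and `asyncio.sleep(0)` are placeholders for awaited LLM and tool calls, not the project's API):

```python
import asyncio
from typing import AsyncIterator


async def react_stream(steps: list[str]) -> AsyncIterator[tuple[str, int]]:
    """Sketch of a natively-async streaming loop: no worker thread, no bridge queue."""
    for step in steps:
        # Tool calls / chat completions would be awaited here instead of being
        # pushed to a thread pool and funneled back through a queue.
        await asyncio.sleep(0)          # stand-in for an awaited tool or chat call
        yield f"{step} ", len(step)     # (delta text, token count)


async def _demo():
    async for delta, _tk in react_stream(["thought", "action", "answer"]):
        print(delta, end="")
    print()


if __name__ == "__main__":
    asyncio.run(_demo())
```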
@@ -433,7 +291,7 @@ class Agent(LLM, ToolBase):
 
             return name, tool_response
 
-        def complete():
+        async def complete():
             nonlocal hist
             need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
             if schema_prompt:
@@ -451,7 +309,7 @@ class Agent(LLM, ToolBase):
             if len(hist) > 12:
                 _hist = [hist[0], hist[1], *hist[-10:]]
             entire_txt = ""
-            for delta_ans in self._generate_streamly(_hist):
+            async for delta_ans in self._generate_streamly_async(_hist):
                 if not need2cite or cited:
                     yield delta_ans, 0
                 entire_txt += delta_ans
@@ -460,7 +318,7 @@ class Agent(LLM, ToolBase):
 
             st = timer()
             txt = ""
-            for delta_ans in self._gen_citations(entire_txt):
+            async for delta_ans in self._gen_citations_async(entire_txt):
                 if self.check_if_canceled("Agent streaming"):
                     return
                 yield delta_ans, 0
@@ -475,14 +333,14 @@ class Agent(LLM, ToolBase):
             hist.append({"role": "user", "content": content})
 
         st = timer()
-        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
+        task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
         self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
         for _ in range(self._param.max_rounds + 1):
             if self.check_if_canceled("Agent streaming"):
                 return
-            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
+            response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
             # self.callback("next_step", {}, str(response)[:256]+"...")
-            token_count += tk
+            token_count += tk or 0
             hist.append({"role": "assistant", "content": response})
             try:
                 functions = json_repair.loads(re.sub(r"```.*", "", response))
@@ -491,23 +349,24 @@ class Agent(LLM, ToolBase):
                 for f in functions:
                     if not isinstance(f, dict):
                         raise TypeError(f"An object type should be returned, but `{f}`")
-                with ThreadPoolExecutor(max_workers=5) as executor:
-                    thr = []
-                    for func in functions:
-                        name = func["name"]
-                        args = func["arguments"]
-                        if name == COMPLETE_TASK:
-                            append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
-                            for txt, tkcnt in complete():
-                                yield txt, tkcnt
-                            return
-
-                        thr.append(executor.submit(use_tool, name, args))
+                tool_tasks = []
+                for func in functions:
+                    name = func["name"]
+                    args = func["arguments"]
+                    if name == COMPLETE_TASK:
+                        append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
+                        async for txt, tkcnt in complete():
+                            yield txt, tkcnt
+                        return
 
-                    st = timer()
-                    reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
-                    append_user_content(hist, reflection)
-                    self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
+                    tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
+
+                results = await asyncio.gather(*tool_tasks) if tool_tasks else []
+                st = timer()
+                reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
+                append_user_content(hist, reflection)
+                self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
 
             except Exception as e:
                 logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
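The hunk above swaps the `ThreadPoolExecutor.submit(...)` fan-out for `asyncio.create_task(...)` plus a single `asyncio.gather(...)`, so tool calls still run concurrently but stay inside the event loop. A runnable sketch of that pattern with a stand-in tool coroutine (names here are illustrative):

```python
import asyncio


async def call_tool(name: str, args: dict) -> tuple[str, str]:
    # Stand-in for an awaited tool call (the real code awaits tool_call_async).
    await asyncio.sleep(0)
    return name, f"result for {args}"


async def run_round(calls: list[tuple[str, dict]]):
    # Schedule every tool call concurrently, then wait for all of them,
    # mirroring the move from executor.submit(...) to create_task(...) + gather(...).
    tasks = [asyncio.create_task(call_tool(name, args)) for name, args in calls]
    return await asyncio.gather(*tasks) if tasks else []


if __name__ == "__main__":
    out = asyncio.run(run_round([("search", {"q": "ragflow"}), ("calc", {"x": 1})]))
    print(out)
```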
@@ -531,21 +390,17 @@ Respond immediately with your final comprehensive answer.
             return
         append_user_content(hist, final_instruction)
 
-        for txt, tkcnt in complete():
+        async for txt, tkcnt in complete():
             yield txt, tkcnt
 
-    def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
-        # self.callback("get_useful_memory", {"topn": 3}, "...")
-        mems = self._canvas.get_memory()
-        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
-        try:
-            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
-            mems = [mems[r] for r in rank]
-            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
-        except Exception as e:
-            logging.exception(e)
-
-        return "Error occurred."
+    async def _gen_citations_async(self, text):
+        retrievals = self._canvas.get_reference()
+        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
+        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
+        async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
+                                                              {"role": "user", "content": text}
+                                                              ]):
+            yield delta_ans
 
     def reset(self, only_output=False):
         """
@@ -327,7 +327,7 @@ class LLM(ComponentBase):
         self.set_output("content", answer)
 
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("LLM processing"):
             return
 
@@ -338,22 +338,25 @@ class LLM(ComponentBase):
 
         prompt, msg, _ = self._prepare_prompt_variables()
         error: str = ""
-        output_structure=None
+        output_structure = None
         try:
-            output_structure = self._param.outputs['structured']
+            output_structure = self._param.outputs["structured"]
         except Exception:
             pass
         if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
-            schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
-            prompt += structured_output_prompt(schema)
-            for _ in range(self._param.max_retries+1):
+            schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
+            prompt_with_schema = prompt + structured_output_prompt(schema)
+            for _ in range(self._param.max_retries + 1):
                 if self.check_if_canceled("LLM processing"):
                     return
 
-                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+                _, msg_fit = message_fit_in(
+                    [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
+                    int(self.chat_mdl.max_length * 0.97),
+                )
                 error = ""
-                ans = self._generate(msg)
-                msg.pop(0)
+                ans = await self._generate_async(msg_fit)
+                msg_fit.pop(0)
                 if ans.find("**ERROR**") >= 0:
                     logging.error(f"LLM response error: {ans}")
                     error = ans
@@ -362,7 +365,7 @@ class LLM(ComponentBase):
                     self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                     return
                 except Exception:
-                    msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
+                    msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
                     error = "The answer can't not be parsed as JSON"
             if error:
                 self.set_output("_ERROR", error)
@@ -370,18 +373,23 @@ class LLM(ComponentBase):
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
-            self.set_output("content", partial(self._stream_output_async, prompt, msg))
+        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
+            ex and ex["goto"]
+        ):
+            self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
             return
 
-        for _ in range(self._param.max_retries+1):
+        error = ""
+        for _ in range(self._param.max_retries + 1):
             if self.check_if_canceled("LLM processing"):
                 return
 
-            _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+            _, msg_fit = message_fit_in(
+                [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
+            )
-            error = ""
-            ans = self._generate(msg)
-            msg.pop(0)
+            ans = await self._generate_async(msg_fit)
+            msg_fit.pop(0)
             if ans.find("**ERROR**") >= 0:
                 logging.error(f"LLM response error: {ans}")
                 error = ans
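Both retry loops above now fit a deep copy of the conversation (`msg_fit`) instead of rewriting `msg` in place, so every attempt starts from the original history and the streaming path receives an unmodified list. A toy illustration of why fitting a copy matters across retries (the `fit_messages` helper below is only a stand-in for `message_fit_in`):

```python
from copy import deepcopy


def fit_messages(messages: list[dict], max_chars: int) -> list[dict]:
    """Toy stand-in for message_fit_in: drop older messages until the budget fits."""
    fitted = deepcopy(messages)
    while fitted and sum(len(m["content"]) for m in fitted) > max_chars:
        fitted.pop(1 if len(fitted) > 1 else 0)  # prefer keeping the system message
    return fitted


history = [{"role": "system", "content": "sys"}, {"role": "user", "content": "x" * 50}]

for _ in range(3):
    # Each retry fits a fresh copy; `history` itself is never mutated, so later
    # attempts (and any streaming consumer) still see the full conversation.
    msg_fit = fit_messages(history, max_chars=40)
    msg_fit.append({"role": "user", "content": "please return valid JSON"})

assert len(history) == 2 and len(history[1]["content"]) == 50
print("original history preserved across retries")
```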
@@ -395,23 +403,9 @@ class LLM(ComponentBase):
             else:
                 self.set_output("_ERROR", error)
 
-    def _stream_output(self, prompt, msg):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer = ""
-        for ans in self._generate_streamly(msg):
-            if self.check_if_canceled("LLM streaming"):
-                return
-
-            if ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", ans)
-                return
-            yield ans
-            answer += ans
-        self.set_output("content", answer)
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
 
     def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
         summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
@@ -49,16 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
         self.callback = callback
 
     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
+        return asyncio.run(self.tool_call_async(name, arguments))
+
+    async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
         st = timer()
         tool_obj = self.tools_map[name]
         if isinstance(tool_obj, MCPToolCallSession):
-            resp = tool_obj.tool_call(name, arguments, 60)
+            resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
         else:
            if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
-                resp = asyncio.run(tool_obj.invoke_async(**arguments))
+                resp = await tool_obj.invoke_async(**arguments)
            else:
-                resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
+                resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
 
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
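`tool_call_async` dispatches by duck typing: MCP sessions and plain synchronous tools are pushed to a worker thread with `asyncio.to_thread`, while tools that expose a coroutine `invoke_async` are awaited directly. A self-contained sketch of that dispatch (the tool classes here are illustrative, not the project's tool components):

```python
import asyncio


class SyncTool:
    def invoke(self, **kwargs):
        return f"sync:{kwargs}"


class AsyncTool:
    async def invoke_async(self, **kwargs):
        return f"async:{kwargs}"


async def dispatch(tool, **arguments):
    # Prefer a native coroutine when the tool provides one; otherwise run the
    # blocking invoke() in a worker thread so the event loop stays responsive.
    invoke_async = getattr(tool, "invoke_async", None)
    if invoke_async and asyncio.iscoroutinefunction(invoke_async):
        return await invoke_async(**arguments)
    return await asyncio.to_thread(tool.invoke, **arguments)


async def _demo():
    print(await dispatch(AsyncTool(), q=1))
    print(await dispatch(SyncTool(), q=2))


if __name__ == "__main__":
    asyncio.run(_demo())
```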
@@ -343,7 +343,8 @@ def form_history(history, limit=-6):
     return context
 
 
-def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
     tools_desc = tool_schema(tools_description)
     context = ""
 
@@ -352,7 +353,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     else:
         template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
         context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
-    kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
+    kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
     if isinstance(kwd, tuple):
         kwd = kwd[0]
     kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@@ -361,13 +362,17 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
     return kwd
 
 
-async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
+async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
+    chat_async = getattr(chat_mdl, "async_chat", None)
+    if chat_async and asyncio.iscoroutinefunction(chat_async):
+        return await chat_async(system, history, **kwargs)
+    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
 
 
-def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
-        return ""
+        return "", 0
     desc = tool_schema(tools_description)
     template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
     user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
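The new `_chat_async` helper applies the same fallback idea to chat models: prefer a native `async_chat` coroutine when the model bundle provides one, otherwise run the blocking `chat` in a thread. A usage sketch with a dummy model that only implements a synchronous `chat` (the `chat(system, history, **kwargs)` signature is assumed from the call sites above):

```python
import asyncio


class DummyChatModel:
    """Illustrative model exposing only a blocking chat() method."""

    def chat(self, system: str, history: list, **kwargs) -> str:
        return f"[{system}] -> {history[-1]['content']}"


async def chat_async(chat_mdl, system: str, history: list, **kwargs) -> str:
    # Same fallback shape as the helper above: await a coroutine if the model
    # has one, otherwise keep the event loop free by using a worker thread.
    native = getattr(chat_mdl, "async_chat", None)
    if native and asyncio.iscoroutinefunction(native):
        return await native(system, history, **kwargs)
    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)


if __name__ == "__main__":
    reply = asyncio.run(chat_async(DummyChatModel(), "You are terse.", [{"role": "user", "content": "hi"}]))
    print(reply)
```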
@@ -376,18 +381,18 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
         hist[-1]["content"] += user_prompt
     else:
         hist.append({"role": "user", "content": user_prompt})
-    json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
-                             hist[1:], stop=["<|stop|>"])
+    json_str = await _chat_async(
+        chat_mdl,
+        template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
+        hist[1:],
+        stop=["<|stop|>"],
+    )
     tk_cnt = num_tokens_from_string(json_str)
     json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
     return json_str, tk_cnt
 
 
-async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
-    return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
-
-
-def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
     template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@@ -398,7 +403,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
     else:
         hist.append({"role": "user", "content": user_prompt})
     _, msg = message_fit_in(hist, chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
+    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
     ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
     return """
 **Observation**
@@ -429,23 +434,15 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
 
 
-def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
     template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
     system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
     user_prompt = " → rank: "
     _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
+    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
 
 
-async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
-
-
-async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
-
-
 def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():
@@ -514,7 +511,7 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
 
     The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
 
-    The response should be in the following JSON format:
+    The response should be in the following JSON format:
     [
         {
             "structure": <structure index, "x.x.x" or None> (string),
@@ -641,8 +638,8 @@ def toc_transformer(toc_pages, chat_mdl):
 
     The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
     The `title` is a short phrase or a several-words term.
 
-    The response should be in the following JSON format:
+    The response should be in the following JSON format:
     [
         {
             "structure": <structure index, "x.x.x" or None> (string),
@@ -667,7 +664,7 @@ def toc_transformer(toc_pages, chat_mdl):
     while not (if_complete == "yes"):
         prompt = f"""
 Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
-The response should be in the following JSON format:
+The response should be in the following JSON format:
 
 The raw table of contents json structure is:
 {toc_content}
@@ -756,7 +753,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
     for chunk in chunks_res:
         titles.extend(chunk.get("toc", []))
 
-
+    # Filter out entries with title == -1
     prune = len(titles) > 512
     max_len = 12 if prune else 22
 