Refa: cleanup synchronous functions in agent_with_tools (#11736)

### What problem does this PR solve?

Clean up the synchronous functions in `agent_with_tools`. In the LLM component, the synchronous `_invoke` becomes an async `_invoke_async` (a thin `asyncio.run` wrapper is kept for synchronous callers), and the synchronous `_stream_output` generator is removed.
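
For orientation, below is a minimal, runnable sketch of the pattern this refactor applies (not the project's actual class; the surrounding structure and the `ComponentSketch` name are simplified stand-ins): the real work lives in an async coroutine, and the previous synchronous entry point survives only as a thin wrapper that drives it with `asyncio.run()`.

```python
import asyncio


class ComponentSketch:
    """Hypothetical stand-in for an agent component, illustrating the sync-to-async cleanup."""

    async def _invoke_async(self, **kwargs):
        # In the real component, all awaitable work (LLM calls, tool calls) happens here.
        await asyncio.sleep(0)  # placeholder for something like `await self._generate_async(...)`
        return f"handled {kwargs}"

    def _invoke(self, **kwargs):
        # Kept for synchronous callers: enters an event loop only for this call.
        return asyncio.run(self._invoke_async(**kwargs))


if __name__ == "__main__":
    print(ComponentSketch()._invoke(prompt="hello"))
```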

### Type of change

- [x] Refactoring
Author: Yongteng Lei
Committed: 2025-12-04 14:15:05 +08:00 (via GitHub)
Parent: 797e03f843
Commit: 27b0550876
4 changed files with 105 additions and 256 deletions


```diff
@@ -327,7 +327,7 @@ class LLM(ComponentBase):
         self.set_output("content", answer)
 
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("LLM processing"):
             return
```
```diff
@@ -338,22 +338,25 @@ class LLM(ComponentBase):
         prompt, msg, _ = self._prepare_prompt_variables()
         error: str = ""
-        output_structure=None
+        output_structure = None
         try:
-            output_structure = self._param.outputs['structured']
+            output_structure = self._param.outputs["structured"]
         except Exception:
             pass
         if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
-            schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
-            prompt += structured_output_prompt(schema)
-            for _ in range(self._param.max_retries+1):
+            schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
+            prompt_with_schema = prompt + structured_output_prompt(schema)
+            for _ in range(self._param.max_retries + 1):
                 if self.check_if_canceled("LLM processing"):
                     return
-                _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+                _, msg_fit = message_fit_in(
+                    [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
+                    int(self.chat_mdl.max_length * 0.97),
+                )
                 error = ""
-                ans = self._generate(msg)
-                msg.pop(0)
+                ans = await self._generate_async(msg_fit)
+                msg_fit.pop(0)
                 if ans.find("**ERROR**") >= 0:
                     logging.error(f"LLM response error: {ans}")
                     error = ans
```
```diff
@@ -362,7 +365,7 @@ class LLM(ComponentBase):
                     self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                     return
                 except Exception:
-                    msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
+                    msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
                     error = "The answer can't not be parsed as JSON"
             if error:
                 self.set_output("_ERROR", error)
```
@ -370,18 +373,23 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output_async, prompt, msg))
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
ex and ex["goto"]
):
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
return
for _ in range(self._param.max_retries+1):
error = ""
for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("LLM processing"):
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
_, msg_fit = message_fit_in(
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
)
error = ""
ans = self._generate(msg)
msg.pop(0)
ans = await self._generate_async(msg_fit)
msg_fit.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
```diff
@@ -395,23 +403,9 @@ class LLM(ComponentBase):
             else:
                 self.set_output("_ERROR", error)
 
-    def _stream_output(self, prompt, msg):
-        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
-        answer = ""
-        for ans in self._generate_streamly(msg):
-            if self.check_if_canceled("LLM streaming"):
-                return
-            if ans.find("**ERROR**") >= 0:
-                if self.get_exception_default_value():
-                    self.set_output("content", self.get_exception_default_value())
-                    yield self.get_exception_default_value()
-                else:
-                    self.set_output("_ERROR", ans)
-                return
-            yield ans
-            answer += ans
-        self.set_output("content", answer)
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
 
     def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
         summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
```