Refa: add result to callback for agent tool use. (#9137)

### What problem does this PR solve?


### Type of change

- [x] Refactoring
Kevin Hu
2025-08-01 21:49:39 +08:00
committed by GitHub
parent c5823a33a3
commit a16cd4f110
26 changed files with 10875 additions and 897 deletions


@@ -22,6 +22,8 @@ from typing import Any
import json_repair
from copy import deepcopy
from functools import partial
from api.db import LLMType
from api.db.services.llm_service import LLMBundle, TenantLLMService
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
@@ -49,27 +51,33 @@ class LLMParam(ComponentParamBase):
self.visual_files_var = None
def check(self):
self.check_decimal_float(self.temperature, "[Agent] Temperature")
self.check_decimal_float(self.presence_penalty, "[Agent] Presence penalty")
self.check_decimal_float(self.frequency_penalty, "[Agent] Frequency penalty")
self.check_nonnegative_number(self.max_tokens, "[Agent] Max tokens")
self.check_decimal_float(self.top_p, "[Agent] Top P")
self.check_decimal_float(float(self.temperature), "[Agent] Temperature")
self.check_decimal_float(float(self.presence_penalty), "[Agent] Presence penalty")
self.check_decimal_float(float(self.frequency_penalty), "[Agent] Frequency penalty")
self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
self.check_decimal_float(float(self.top_p), "[Agent] Top P")
self.check_empty(self.llm_id, "[Agent] LLM")
self.check_empty(self.sys_prompt, "[Agent] System prompt")
self.check_empty(self.prompts, "[Agent] User prompt")
def gen_conf(self):
conf = {}
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
def get_attr(nm):
try:
return getattr(self, nm)
except Exception:
pass
if int(self.max_tokens) > 0 and get_attr("maxTokensEnabled"):
conf["max_tokens"] = int(self.max_tokens)
if float(self.temperature) > 0 and get_attr("temperatureEnabled"):
conf["temperature"] = float(self.temperature)
if float(self.top_p) > 0 and get_attr("topPEnabled"):
conf["top_p"] = float(self.top_p)
if float(self.presence_penalty) > 0 and get_attr("presencePenaltyEnabled"):
conf["presence_penalty"] = float(self.presence_penalty)
if float(self.frequency_penalty) > 0 and get_attr("frequencyPenaltyEnabled"):
conf["frequency_penalty"] = float(self.frequency_penalty)
return conf
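
The reworked `check()` casts each value with `float()`/`int()` before validating, and `gen_conf()` now emits a sampling parameter only when its value is positive and the matching `*Enabled` toggle is set; `get_attr` swallows missing attributes, so graphs saved before the toggles existed simply skip the parameter. A minimal standalone sketch of that gating pattern, assuming values can arrive as strings from the canvas config (the casts in the diff suggest this) and using `SimpleNamespace` as a stand-in for `LLMParam`:

```python
from types import SimpleNamespace

def gen_conf(param):
    """Sketch of the value-positive AND toggle-enabled gate used above."""
    conf = {}

    def get_attr(nm):
        try:
            return getattr(param, nm)
        except AttributeError:
            return None  # toggle absent on older configs -> parameter skipped

    if int(param.max_tokens) > 0 and get_attr("maxTokensEnabled"):
        conf["max_tokens"] = int(param.max_tokens)
    if float(param.temperature) > 0 and get_attr("temperatureEnabled"):
        conf["temperature"] = float(param.temperature)
    return conf

# Values cast from strings, as the canvas config may store them that way.
p = SimpleNamespace(max_tokens="512", temperature="0.2", maxTokensEnabled=True)
print(gen_conf(p))  # {'max_tokens': 512} -- temperatureEnabled is absent, so it is dropped
```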
@@ -112,6 +120,12 @@ class LLM(ComponentBase):
if not self.imgs:
self.imgs = []
self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"]
if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
self._param.llm_id, max_retries=self._param.max_retries,
retry_interval=self._param.delay_after_error
)
args = {}
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
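
Two behavioural notes on this hunk: only inline `data:image/...` URIs survive the filter (the slice comparison is equivalent to `startswith`), and when such images are attached while `llm_id` resolves to a plain CHAT model, the bundle is re-created as an IMAGE2TEXT model so the vision call can consume them. A tiny hedged sketch of the filter with made-up values:

```python
imgs = [
    "data:image/png;base64,iVBORw0KGgo...",   # kept: inline data URI
    "https://example.com/photo.png",          # dropped: plain URL
    "data:image/jpeg;base64,/9j/4AAQ...",     # kept
]
imgs = [img for img in imgs if img[:len("data:image/")] == "data:image/"]
print(len(imgs))  # 2
```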
@@ -207,7 +221,8 @@ class LLM(ComponentBase):
return
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output, prompt, msg))
return
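
The streaming path is now additionally gated on the exception handler: if a handler with a `goto` target is configured, the component no longer hands a streaming partial to a downstream Message node. A hedged sketch of that predicate (the helper name and arguments are illustrative, not from the diff):

```python
def should_stream(downstream_names, output_structure, ex):
    # Stream only when a Message component is downstream, no structured
    # output is requested, and no exception "goto" branch is configured.
    has_message = any(name.lower() == "message" for name in downstream_names)
    return has_message and not output_structure and not (ex and ex.get("goto"))

print(should_stream(["Message"], None, None))               # True
print(should_stream(["Message"], None, {"goto": "retry"}))  # False
```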
@@ -224,14 +239,22 @@ class LLM(ComponentBase):
break
if error:
self.set_output("_ERROR", error)
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
else:
self.set_output("_ERROR", error)
def _stream_output(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield ans
answer += ans
self.set_output("content", answer)
@@ -243,4 +266,4 @@ class LLM(ComponentBase):
def thoughts(self) -> str:
_, msg = self._prepare_prompt_variables()
return f"Im thinking and planning the next move, starting from the prompt:<br/>“{msg[-1]['content']}”<span class=\"collapse\"> (tap to see full text)</span>"
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nIll figure out our best next move."