Refa: add result to callback for agent tool use. (#9137)

### What problem does this PR solve?


### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu
2025-08-01 21:49:39 +08:00
committed by GitHub
parent c5823a33a3
commit a16cd4f110
26 changed files with 10875 additions and 897 deletions

View File

@ -157,7 +157,8 @@ class Agent(LLM, ToolBase):
prompt, msg = self._prepare_prompt_variables()
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
return
@ -169,7 +170,10 @@ class Agent(LLM, ToolBase):
if ans.find("**ERROR**") >= 0:
logging.error(f"Agent._chat got error. response: {ans}")
self.set_output("_ERROR", ans)
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
else:
self.set_output("_ERROR", ans)
return
self.set_output("content", ans)
@ -182,6 +186,12 @@ class Agent(LLM, ToolBase):
answer_without_toolcall = ""
use_tools = []
for delta_ans,_ in self._react_with_tools_streamly(msg, use_tools):
if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", delta_ans)
answer_without_toolcall += delta_ans
yield delta_ans
@ -204,8 +214,8 @@ class Agent(LLM, ToolBase):
hist = deepcopy(history)
last_calling = ""
if len(hist) > 3:
self.callback("Multi-turn conversation optimization", {}, " running ...")
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
self.callback("Multi-turn conversation optimization", {}, user_request)
else:
user_request = history[-1]["content"]
@ -241,9 +251,6 @@ class Agent(LLM, ToolBase):
cited = True
yield "", token_count
if not cited and need2cite:
self.callback("gen_citations", {}, " running ...")
_hist = hist
if len(hist) > 12:
_hist = [hist[0], hist[1], *hist[-10:]]
@ -255,8 +262,12 @@ class Agent(LLM, ToolBase):
if not need2cite or cited:
return
txt = ""
for delta_ans in self._gen_citations(entire_txt):
yield delta_ans, 0
txt += delta_ans
self.callback("gen_citations", {}, txt)
def append_user_content(hist, content):
if hist[-1]["role"] == "user":
@ -264,8 +275,8 @@ class Agent(LLM, ToolBase):
else:
hist.append({"role": "user", "content": content})
self.callback("analyze_task", {}, " running ...")
task_desc = analyze_task(self.chat_mdl, user_request, tool_metas)
self.callback("analyze_task", {}, task_desc)
for _ in range(self._param.max_rounds + 1):
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
# self.callback("next_step", {}, str(response)[:256]+"...")

View File

@ -44,7 +44,6 @@ class ComponentParamBase(ABC):
self.delay_after_error = 2.0
self.exception_method = None
self.exception_default_value = None
self.exception_comment = None
self.exception_goto = None
self.debug_inputs = {}
@ -97,6 +96,14 @@ class ComponentParamBase(ABC):
def as_dict(self):
def _recursive_convert_obj_to_dict(obj):
ret_dict = {}
if isinstance(obj, dict):
for k,v in obj.items():
if isinstance(v, dict) or (v and type(v).__name__ not in dir(builtins)):
ret_dict[k] = _recursive_convert_obj_to_dict(v)
else:
ret_dict[k] = v
return ret_dict
for attr_name in list(obj.__dict__):
if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
continue
@ -105,7 +112,7 @@ class ComponentParamBase(ABC):
if isinstance(attr, pd.DataFrame):
ret_dict[attr_name] = attr.to_dict()
continue
if attr and type(attr).__name__ not in dir(builtins):
if isinstance(attr, dict) or (attr and type(attr).__name__ not in dir(builtins)):
ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
else:
ret_dict[attr_name] = attr
@ -415,7 +422,10 @@ class ComponentBase(ABC):
try:
self._invoke(**kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self._param.debug_inputs = {}
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
@ -427,7 +437,7 @@ class ComponentBase(ABC):
def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
if var_nm:
return self._param.outputs.get(var_nm, {}).get("value")
return self._param.outputs.get(var_nm, {}).get("value", "")
return {k: o.get("value") for k,o in self._param.outputs.items()}
def set_output(self, key: str, value: Any):
@ -520,7 +530,7 @@ class ComponentBase(ABC):
def string_format(content: str, kv: dict[str, str]) -> str:
for n, v in kv.items():
content = re.sub(
r"\{%s\}" % re.escape(n), re.escape(v), content
r"\{%s\}" % re.escape(n), v, content
)
return content
@ -529,13 +539,17 @@ class ComponentBase(ABC):
return
return {
"goto": self._param.exception_goto,
"comment": self._param.exception_comment,
"default_value": self._param.exception_default_value
}
def get_exception_default_value(self):
if self._param.exception_method != "comment":
return ""
return self._param.exception_default_value
def set_exception_default_value(self):
self.set_output("result", self.get_exception_default_value())
@abstractmethod
def thoughts(self) -> str:
...

View File

@ -46,4 +46,4 @@ class Begin(UserFillUp):
self.set_input_value(k, v)
def thoughts(self) -> str:
return "☕ Here we go..."
return ""

View File

@ -22,6 +22,8 @@ from typing import Any
import json_repair
from copy import deepcopy
from functools import partial
from api.db import LLMType
from api.db.services.llm_service import LLMBundle, TenantLLMService
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
@ -49,27 +51,33 @@ class LLMParam(ComponentParamBase):
self.visual_files_var = None
def check(self):
self.check_decimal_float(self.temperature, "[Agent] Temperature")
self.check_decimal_float(self.presence_penalty, "[Agent] Presence penalty")
self.check_decimal_float(self.frequency_penalty, "[Agent] Frequency penalty")
self.check_nonnegative_number(self.max_tokens, "[Agent] Max tokens")
self.check_decimal_float(self.top_p, "[Agent] Top P")
self.check_decimal_float(float(self.temperature), "[Agent] Temperature")
self.check_decimal_float(float(self.presence_penalty), "[Agent] Presence penalty")
self.check_decimal_float(float(self.frequency_penalty), "[Agent] Frequency penalty")
self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
self.check_decimal_float(float(self.top_p), "[Agent] Top P")
self.check_empty(self.llm_id, "[Agent] LLM")
self.check_empty(self.sys_prompt, "[Agent] System prompt")
self.check_empty(self.prompts, "[Agent] User prompt")
def gen_conf(self):
conf = {}
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
def get_attr(nm):
try:
return getattr(self, nm)
except Exception:
pass
if int(self.max_tokens) > 0 and get_attr("maxTokensEnabled"):
conf["max_tokens"] = int(self.max_tokens)
if float(self.temperature) > 0 and get_attr("temperatureEnabled"):
conf["temperature"] = float(self.temperature)
if float(self.top_p) > 0 and get_attr("topPEnabled"):
conf["top_p"] = float(self.top_p)
if float(self.presence_penalty) > 0 and get_attr("presencePenaltyEnabled"):
conf["presence_penalty"] = float(self.presence_penalty)
if float(self.frequency_penalty) > 0 and get_attr("frequencyPenaltyEnabled"):
conf["frequency_penalty"] = float(self.frequency_penalty)
return conf
@ -112,6 +120,12 @@ class LLM(ComponentBase):
if not self.imgs:
self.imgs = []
self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"]
if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
self._param.llm_id, max_retries=self._param.max_retries,
retry_interval=self._param.delay_after_error
)
args = {}
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
@ -207,7 +221,8 @@ class LLM(ComponentBase):
return
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output, prompt, msg))
return
@ -224,14 +239,22 @@ class LLM(ComponentBase):
break
if error:
self.set_output("_ERROR", error)
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
else:
self.set_output("_ERROR", error)
def _stream_output(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield ans
answer += ans
self.set_output("content", answer)
@ -243,4 +266,4 @@ class LLM(ComponentBase):
def thoughts(self) -> str:
_, msg = self._prepare_prompt_variables()
return f"Im thinking and planning the next move, starting from the prompt:<br/>“{msg[-1]['content']}”<span class=\"collapse\"> (tap to see full text)</span>"
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nIll figure out our best next move."

View File

@ -143,4 +143,4 @@ class Message(ComponentBase):
self.set_output("content", content)
def thoughts(self) -> str:
return "Thinking ..."
return ""