Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refa: add result to callback for agent tool use. (#9137)
### What problem does this PR solve?

### Type of change

- [x] Refactoring
```diff
@@ -157,7 +157,8 @@ class Agent(LLM, ToolBase):
         prompt, msg = self._prepare_prompt_variables()
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
+        ex = self.exception_handler()
+        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
             self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
             return
 
```
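The new gate reads: stream directly into a downstream Message component only when no structured output is requested and no exception `goto` route is configured, since a `goto` on failure would reroute the flow past that Message node anyway. A minimal standalone sketch of the decision (hypothetical `Node` container and function names, not ragflow's canvas API):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in for a canvas component."""
    component_name: str
    downstream: list[str] = field(default_factory=list)


def should_stream_to_message(nodes: dict[str, Node], node_id: str,
                             output_structure: Optional[dict],
                             exception_handler: Optional[dict]) -> bool:
    """Mirror of the gate added in this hunk: stream only if a downstream
    Message exists, no structured output is requested, and no exception
    'goto' route is configured."""
    downstreams = nodes[node_id].downstream if node_id in nodes else []
    has_message = any(nodes[cid].component_name.lower() == "message" for cid in downstreams)
    goto_on_error = bool(exception_handler and exception_handler.get("goto"))
    return has_message and not output_structure and not goto_on_error


nodes = {
    "agent:0": Node("Agent", downstream=["message:0"]),
    "message:0": Node("Message"),
}
print(should_stream_to_message(nodes, "agent:0", None, None))                     # True
print(should_stream_to_message(nodes, "agent:0", None, {"goto": ["message:1"]}))  # False
```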
```diff
@@ -169,7 +170,10 @@ class Agent(LLM, ToolBase):
 
         if ans.find("**ERROR**") >= 0:
             logging.error(f"Agent._chat got error. response: {ans}")
-            self.set_output("_ERROR", ans)
+            if self.get_exception_default_value():
+                self.set_output("content", self.get_exception_default_value())
+            else:
+                self.set_output("_ERROR", ans)
             return
 
         self.set_output("content", ans)
```
```diff
@@ -182,6 +186,12 @@ class Agent(LLM, ToolBase):
         answer_without_toolcall = ""
         use_tools = []
         for delta_ans,_ in self._react_with_tools_streamly(msg, use_tools):
+            if delta_ans.find("**ERROR**") >= 0:
+                if self.get_exception_default_value():
+                    self.set_output("content", self.get_exception_default_value())
+                    yield self.get_exception_default_value()
+                else:
+                    self.set_output("_ERROR", delta_ans)
             answer_without_toolcall += delta_ans
             yield delta_ans
 
```
```diff
@@ -204,8 +214,8 @@ class Agent(LLM, ToolBase):
         hist = deepcopy(history)
         last_calling = ""
         if len(hist) > 3:
             self.callback("Multi-turn conversation optimization", {}, " running ...")
             user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
+            self.callback("Multi-turn conversation optimization", {}, user_request)
         else:
             user_request = history[-1]["content"]
 
```
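This is the change the PR title points at: each `self.callback(event, inputs, result)` step now fires a second time with the actual result (here the condensed multi-turn question) instead of only the placeholder `" running ..."`. A sketch of a consumer that benefits from this, with a hypothetical `trace` list standing in for ragflow's canvas event sink and a crude join standing in for `full_question`:

```python
from typing import Any, Callable

trace: list[dict] = []


def callback(event: str, inputs: dict, result: Any) -> None:
    """Hypothetical sink: ragflow forwards these events to the canvas/UI."""
    trace.append({"event": event, "inputs": inputs, "result": result})


def condense_history(history: list[dict], cb: Callable[[str, dict, Any], None]) -> str:
    """Mimics the hunk above: announce the step, then report its result."""
    if len(history) > 3:
        cb("Multi-turn conversation optimization", {}, " running ...")
        # stand-in for full_question(messages=history, chat_mdl=...)
        user_request = " ".join(m["content"] for m in history if m["role"] == "user")
        cb("Multi-turn conversation optimization", {}, user_request)
        return user_request
    return history[-1]["content"]


history = [{"role": "user", "content": "Hi"},
           {"role": "assistant", "content": "Hello!"},
           {"role": "user", "content": "Compare plan A and plan B"},
           {"role": "user", "content": "Focus on pricing"}]
condense_history(history, callback)
print([t["result"] for t in trace])
# [' running ...', 'Hi Compare plan A and plan B Focus on pricing']
```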
```diff
@@ -241,9 +251,6 @@ class Agent(LLM, ToolBase):
                     cited = True
             yield "", token_count
 
-            if not cited and need2cite:
-                self.callback("gen_citations", {}, " running ...")
-
             _hist = hist
             if len(hist) > 12:
                 _hist = [hist[0], hist[1], *hist[-10:]]
```
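The `_hist = [hist[0], hist[1], *hist[-10:]]` trick keeps the system prompt and the original user request plus only the most recent ten messages once the tool-calling transcript grows past twelve entries. A tiny worked example with plain dicts (no ragflow types):

```python
def truncate_history(hist: list[dict], keep_tail: int = 10) -> list[dict]:
    """Keep the first two messages (system prompt + initial request)
    and the last `keep_tail` messages, as the hunk above does for len > 12."""
    if len(hist) <= keep_tail + 2:
        return hist
    return [hist[0], hist[1], *hist[-keep_tail:]]


hist = [{"role": "system", "content": "You are an agent."},
        {"role": "user", "content": "Plan my trip."}]
hist += [{"role": "assistant", "content": f"tool call {i}"} for i in range(20)]

short = truncate_history(hist)
print(len(hist), "->", len(short))   # 22 -> 12
print(short[2]["content"])           # tool call 10
```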
```diff
@@ -255,8 +262,12 @@ class Agent(LLM, ToolBase):
             if not need2cite or cited:
                 return
 
+            txt = ""
             for delta_ans in self._gen_citations(entire_txt):
                 yield delta_ans, 0
+                txt += delta_ans
+
+            self.callback("gen_citations", {}, txt)
 
         def append_user_content(hist, content):
             if hist[-1]["role"] == "user":
```
```diff
@@ -264,8 +275,8 @@ class Agent(LLM, ToolBase):
             else:
                 hist.append({"role": "user", "content": content})
 
         self.callback("analyze_task", {}, " running ...")
         task_desc = analyze_task(self.chat_mdl, user_request, tool_metas)
+        self.callback("analyze_task", {}, task_desc)
         for _ in range(self._param.max_rounds + 1):
             response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
             # self.callback("next_step", {}, str(response)[:256]+"...")
 
```
```diff
@@ -44,7 +44,6 @@ class ComponentParamBase(ABC):
         self.delay_after_error = 2.0
         self.exception_method = None
         self.exception_default_value = None
         self.exception_comment = None
         self.exception_goto = None
         self.debug_inputs = {}
 
```
```diff
@@ -97,6 +96,14 @@ class ComponentParamBase(ABC):
     def as_dict(self):
         def _recursive_convert_obj_to_dict(obj):
             ret_dict = {}
+            if isinstance(obj, dict):
+                for k,v in obj.items():
+                    if isinstance(v, dict) or (v and type(v).__name__ not in dir(builtins)):
+                        ret_dict[k] = _recursive_convert_obj_to_dict(v)
+                    else:
+                        ret_dict[k] = v
+                return ret_dict
+
             for attr_name in list(obj.__dict__):
                 if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
                     continue
```
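The new branch makes `as_dict()` recurse into plain dict attributes whose values are themselves non-builtin objects (e.g. nested param objects), instead of serializing them as opaque instances. A standalone sketch of the same idea, simplified (no pandas handling or deprecated-param bookkeeping, hypothetical class names):

```python
import builtins


def to_dict(obj):
    """Recursively convert an object graph (objects with __dict__, dicts,
    builtin scalars/containers) into plain dicts, mirroring the hunk above."""
    def is_custom(v):
        return v is not None and type(v).__name__ not in dir(builtins)

    if isinstance(obj, dict):
        return {k: to_dict(v) if isinstance(v, dict) or is_custom(v) else v
                for k, v in obj.items()}
    out = {}
    for name, attr in vars(obj).items():
        if isinstance(attr, dict) or is_custom(attr):
            out[name] = to_dict(attr)
        else:
            out[name] = attr
    return out


class Inner:
    def __init__(self):
        self.threshold = 0.5


class Outer:
    def __init__(self):
        self.name = "demo"
        self.children = {"inner": Inner()}   # dict holding a custom object


print(to_dict(Outer()))
# {'name': 'demo', 'children': {'inner': {'threshold': 0.5}}}
```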
```diff
@@ -105,7 +112,7 @@ class ComponentParamBase(ABC):
                 if isinstance(attr, pd.DataFrame):
                     ret_dict[attr_name] = attr.to_dict()
                     continue
-                if attr and type(attr).__name__ not in dir(builtins):
+                if isinstance(attr, dict) or (attr and type(attr).__name__ not in dir(builtins)):
                     ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
                 else:
                     ret_dict[attr_name] = attr
```
```diff
@@ -415,7 +422,10 @@ class ComponentBase(ABC):
         try:
             self._invoke(**kwargs)
         except Exception as e:
-            self._param.outputs["_ERROR"] = {"value": str(e)}
+            if self.get_exception_default_value():
+                self.set_exception_default_value()
+            else:
+                self.set_output("_ERROR", str(e))
             logging.exception(e)
         self._param.debug_inputs = {}
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
```
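The surrounding invoke wrapper now routes failures through the same default-value mechanism as the streaming paths while still recording elapsed time. A rough standalone analogue (hypothetical `Component` shell, not the real `ComponentBase`):

```python
import logging
import time


class Component:
    """Hypothetical shell showing the invoke contract from the hunk above."""

    def __init__(self, exception_default_value: str = ""):
        self.outputs: dict = {"_created_time": time.perf_counter()}
        self.exception_default_value = exception_default_value

    def _invoke(self, **kwargs):
        raise RuntimeError("backend unavailable")   # simulate a failure

    def invoke(self, **kwargs):
        try:
            self._invoke(**kwargs)
        except Exception as e:
            if self.exception_default_value:
                self.outputs["result"] = self.exception_default_value   # canned answer
            else:
                self.outputs["_ERROR"] = str(e)                         # surface the error
            logging.exception(e)   # traceback goes to the log, not the output
        self.outputs["_elapsed_time"] = time.perf_counter() - self.outputs["_created_time"]
        return self.outputs


print(Component(exception_default_value="N/A").invoke())
# e.g. {'_created_time': 1234.56, 'result': 'N/A', '_elapsed_time': 0.0003}
```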
```diff
@@ -427,7 +437,7 @@ class ComponentBase(ABC):
 
     def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
         if var_nm:
-            return self._param.outputs.get(var_nm, {}).get("value")
+            return self._param.outputs.get(var_nm, {}).get("value", "")
         return {k: o.get("value") for k,o in self._param.outputs.items()}
 
     def set_output(self, key: str, value: Any):
```
```diff
@@ -520,7 +530,7 @@ class ComponentBase(ABC):
     def string_format(content: str, kv: dict[str, str]) -> str:
         for n, v in kv.items():
             content = re.sub(
-                r"\{%s\}" % re.escape(n), re.escape(v), content
+                r"\{%s\}" % re.escape(n), v, content
             )
         return content
 
```
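The one-argument-looking change matters: `re.escape` is meant for text that will be compiled as a pattern, not for the replacement argument of `re.sub`. Escaping the replacement injects literal backslashes in front of characters such as `(`, `[` and spaces, so substituted variable values came out mangled. A quick demonstration of the before/after behavior:

```python
import re


def string_format_old(content: str, kv: dict) -> str:
    for n, v in kv.items():
        content = re.sub(r"\{%s\}" % re.escape(n), re.escape(v), content)   # buggy
    return content


def string_format_new(content: str, kv: dict) -> str:
    for n, v in kv.items():
        content = re.sub(r"\{%s\}" % re.escape(n), v, content)              # fixed
    return content


kv = {"name": "Agent (v2) [beta]"}
print(string_format_old("Hello {name}!", kv))   # Hello Agent\ \(v2\)\ \[beta\]!  (stray backslashes)
print(string_format_new("Hello {name}!", kv))   # Hello Agent (v2) [beta]!
```

One remaining caveat: the fixed version still treats backslashes inside `v` as replacement-template escapes; if values can legitimately contain backslashes, passing a callable (`re.sub(pat, lambda m: v, content)`) avoids template processing entirely.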
```diff
@@ -529,13 +539,17 @@ class ComponentBase(ABC):
             return
         return {
             "goto": self._param.exception_goto,
             "comment": self._param.exception_comment,
+            "default_value": self._param.exception_default_value
         }
 
     def get_exception_default_value(self):
+        if self._param.exception_method != "comment":
+            return ""
         return self._param.exception_default_value
 
+    def set_exception_default_value(self):
+        self.set_output("result", self.get_exception_default_value())
+
     @abstractmethod
     def thoughts(self) -> str:
         ...
```
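Taken together, `exception_handler()` and `get_exception_default_value()` define a small contract: `goto` reroutes the flow to other components on failure, `comment` substitutes a canned default (written to the `result` output by `set_exception_default_value`), and only the `comment` mode exposes a non-empty default. A compact sketch of how a caller might branch on it (hypothetical `ExceptionPolicy` holder, not the real param object):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExceptionPolicy:
    """Hypothetical mirror of the exception_* params on ComponentParamBase."""
    method: Optional[str] = None          # None | "comment" | "goto"
    default_value: str = ""
    comment: str = ""
    goto: Optional[list] = None

    def handler(self) -> Optional[dict]:
        if not self.method:
            return None
        return {"goto": self.goto, "comment": self.comment,
                "default_value": self.default_value}

    def default(self) -> str:
        # Only the "comment" mode substitutes a default answer.
        return self.default_value if self.method == "comment" else ""


def on_failure(policy: ExceptionPolicy, error: str) -> str:
    ex = policy.handler()
    if ex and ex["goto"]:
        return f"reroute to {ex['goto']}"
    if policy.default():
        return f"answer with default: {policy.default()}"
    return f"raise: {error}"


print(on_failure(ExceptionPolicy(), "boom"))                                        # raise: boom
print(on_failure(ExceptionPolicy(method="comment", default_value="N/A"), "boom"))   # answer with default: N/A
print(on_failure(ExceptionPolicy(method="goto", goto=["fallback:0"]), "boom"))      # reroute to ['fallback:0']
```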
```diff
@@ -46,4 +46,4 @@ class Begin(UserFillUp):
             self.set_input_value(k, v)
 
     def thoughts(self) -> str:
-        return "☕ Here we go..."
+        return ""
```
```diff
@@ -22,6 +22,8 @@ from typing import Any
 import json_repair
+from copy import deepcopy
+from functools import partial
 
 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle, TenantLLMService
 from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
```
```diff
@@ -49,27 +51,33 @@ class LLMParam(ComponentParamBase):
         self.visual_files_var = None
 
     def check(self):
-        self.check_decimal_float(self.temperature, "[Agent] Temperature")
-        self.check_decimal_float(self.presence_penalty, "[Agent] Presence penalty")
-        self.check_decimal_float(self.frequency_penalty, "[Agent] Frequency penalty")
-        self.check_nonnegative_number(self.max_tokens, "[Agent] Max tokens")
-        self.check_decimal_float(self.top_p, "[Agent] Top P")
+        self.check_decimal_float(float(self.temperature), "[Agent] Temperature")
+        self.check_decimal_float(float(self.presence_penalty), "[Agent] Presence penalty")
+        self.check_decimal_float(float(self.frequency_penalty), "[Agent] Frequency penalty")
+        self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
+        self.check_decimal_float(float(self.top_p), "[Agent] Top P")
         self.check_empty(self.llm_id, "[Agent] LLM")
         self.check_empty(self.sys_prompt, "[Agent] System prompt")
         self.check_empty(self.prompts, "[Agent] User prompt")
 
     def gen_conf(self):
         conf = {}
-        if self.max_tokens > 0:
-            conf["max_tokens"] = self.max_tokens
-        if self.temperature > 0:
-            conf["temperature"] = self.temperature
-        if self.top_p > 0:
-            conf["top_p"] = self.top_p
-        if self.presence_penalty > 0:
-            conf["presence_penalty"] = self.presence_penalty
-        if self.frequency_penalty > 0:
-            conf["frequency_penalty"] = self.frequency_penalty
+        def get_attr(nm):
+            try:
+                return getattr(self, nm)
+            except Exception:
+                pass
+
+        if int(self.max_tokens) > 0 and get_attr("maxTokensEnabled"):
+            conf["max_tokens"] = int(self.max_tokens)
+        if float(self.temperature) > 0 and get_attr("temperatureEnabled"):
+            conf["temperature"] = float(self.temperature)
+        if float(self.top_p) > 0 and get_attr("topPEnabled"):
+            conf["top_p"] = float(self.top_p)
+        if float(self.presence_penalty) > 0 and get_attr("presencePenaltyEnabled"):
+            conf["presence_penalty"] = float(self.presence_penalty)
+        if float(self.frequency_penalty) > 0 and get_attr("frequencyPenaltyEnabled"):
+            conf["frequency_penalty"] = float(self.frequency_penalty)
         return conf
 
 
```
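Two things change in `gen_conf`: values coming from the canvas JSON may arrive as strings, so they are coerced with `float()`/`int()` before comparison, and each sampling parameter is only emitted when its matching `*Enabled` flag is set, looked up defensively because older graphs may not carry those attributes. A standalone sketch of the same construction (flag and field names taken from the hunk; the plain `params` dict is my stand-in for the param object):

```python
def gen_conf(params: dict) -> dict:
    """Build an LLM sampling config from loosely-typed canvas params,
    mirroring the coercion + '*Enabled' gating in the hunk above."""
    def get(nm, default=None):
        return params.get(nm, default)   # defensive lookup, like get_attr()

    conf = {}
    if int(get("max_tokens", 0)) > 0 and get("maxTokensEnabled"):
        conf["max_tokens"] = int(get("max_tokens"))
    if float(get("temperature", 0)) > 0 and get("temperatureEnabled"):
        conf["temperature"] = float(get("temperature"))
    if float(get("top_p", 0)) > 0 and get("topPEnabled"):
        conf["top_p"] = float(get("top_p"))
    return conf


# Values arrive as strings from the front end; flags gate what is sent.
print(gen_conf({"temperature": "0.7", "temperatureEnabled": True,
                "top_p": "0.9", "topPEnabled": False,
                "max_tokens": "512", "maxTokensEnabled": True}))
# {'max_tokens': 512, 'temperature': 0.7}
```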
```diff
@@ -112,6 +120,12 @@ class LLM(ComponentBase):
         if not self.imgs:
             self.imgs = []
+        self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"]
+        if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
+            self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
+                                      self._param.llm_id, max_retries=self._param.max_retries,
+                                      retry_interval=self._param.delay_after_error
+                                      )
 
 
         args = {}
         vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
```
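The new lines drop any attached image that is not an inline `data:image/...` data URL and, when images remain but the configured model is a plain chat model, rebind the bundle to the image-to-text variant of the same `llm_id`. The filter itself is just a prefix check; the slice comparison in the hunk and `startswith` below are equivalent:

```python
def keep_data_image_urls(imgs: list[str]) -> list[str]:
    """Equivalent of the list comprehension in the hunk: keep only
    inline data-URL images, drop plain paths, remote URLs or garbage."""
    return [img for img in imgs if img.startswith("data:image/")]


imgs = ["data:image/png;base64,iVBORw0KGgo...",
        "https://example.com/cat.png",   # remote URL: dropped
        "/tmp/upload.jpg"]               # local path: dropped
print(len(keep_data_image_urls(imgs)))   # 1
```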
```diff
@@ -207,7 +221,8 @@ class LLM(ComponentBase):
             return
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
+        ex = self.exception_handler()
+        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
             self.set_output("content", partial(self._stream_output, prompt, msg))
             return
 
```
```diff
@@ -224,14 +239,22 @@ class LLM(ComponentBase):
                 break
 
         if error:
-            self.set_output("_ERROR", error)
+            if self.get_exception_default_value():
+                self.set_output("content", self.get_exception_default_value())
+            else:
+                self.set_output("_ERROR", error)
 
     def _stream_output(self, prompt, msg):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer = ""
         for ans in self._generate_streamly(msg):
             if ans.find("**ERROR**") >= 0:
+                if self.get_exception_default_value():
+                    self.set_output("content", self.get_exception_default_value())
+                    yield self.get_exception_default_value()
+                else:
                     self.set_output("_ERROR", ans)
                 return
             yield ans
             answer += ans
         self.set_output("content", answer)
```
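`_stream_output` is a generator: it yields deltas to the caller while accumulating the full answer, and on an in-band `**ERROR**` sentinel it either yields the configured default once or records the error and stops early. A stripped-down sketch of that shape (plain iterables and an output dict instead of an LLM bundle and component outputs):

```python
from collections.abc import Iterable


def stream_output(deltas: Iterable[str], default_value: str, outputs: dict):
    """Yield deltas while accumulating the final answer; on the '**ERROR**'
    sentinel, fall back to the default (if any) and stop early."""
    answer = ""
    for ans in deltas:
        if "**ERROR**" in ans:
            if default_value:
                outputs["content"] = default_value
                yield default_value
            else:
                outputs["_ERROR"] = ans
            return
        yield ans
        answer += ans
    outputs["content"] = answer


outputs: dict = {}
print(list(stream_output(["Hel", "lo"], "", outputs)), outputs)
# ['Hel', 'lo'] {'content': 'Hello'}

outputs = {}
print(list(stream_output(["**ERROR**: timeout"], "fallback", outputs)), outputs)
# ['fallback'] {'content': 'fallback'}
```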
```diff
@@ -243,4 +266,4 @@ class LLM(ComponentBase):
 
     def thoughts(self) -> str:
         _, msg = self._prepare_prompt_variables()
-        return f"I’m thinking and planning the next move, starting from the prompt:<br/>“{msg[-1]['content']}”<span class=\"collapse\"> (tap to see full text)</span>"
+        return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
```
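The new `thoughts()` text is built by stripping the `User's query:` label and any backslash runs from the last prompt message before echoing it back to the UI. The substitution can be checked in isolation:

```python
import re


def clean_prompt(text: str) -> str:
    """Same substitution as in the hunk: drop the "User's query:" label
    and any runs of backslashes before echoing the prompt back."""
    return re.sub(r"(User's query:|[\\]+)", '', text, flags=re.DOTALL)


print(clean_prompt("User's query: compare Q3 \\\\ Q4 revenue"))
# ' compare Q3  Q4 revenue'
```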
```diff
@@ -143,4 +143,4 @@ class Message(ComponentBase):
         self.set_output("content", content)
 
     def thoughts(self) -> str:
-        return "Thinking ..."
+        return ""
```