mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: add result to callback for agent tool use. (#9137)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
@ -22,6 +22,8 @@ from typing import Any
|
||||
import json_repair
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
@ -49,27 +51,33 @@ class LLMParam(ComponentParamBase):
|
||||
self.visual_files_var = None
|
||||
|
||||
def check(self):
|
||||
self.check_decimal_float(self.temperature, "[Agent] Temperature")
|
||||
self.check_decimal_float(self.presence_penalty, "[Agent] Presence penalty")
|
||||
self.check_decimal_float(self.frequency_penalty, "[Agent] Frequency penalty")
|
||||
self.check_nonnegative_number(self.max_tokens, "[Agent] Max tokens")
|
||||
self.check_decimal_float(self.top_p, "[Agent] Top P")
|
||||
self.check_decimal_float(float(self.temperature), "[Agent] Temperature")
|
||||
self.check_decimal_float(float(self.presence_penalty), "[Agent] Presence penalty")
|
||||
self.check_decimal_float(float(self.frequency_penalty), "[Agent] Frequency penalty")
|
||||
self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
|
||||
self.check_decimal_float(float(self.top_p), "[Agent] Top P")
|
||||
self.check_empty(self.llm_id, "[Agent] LLM")
|
||||
self.check_empty(self.sys_prompt, "[Agent] System prompt")
|
||||
self.check_empty(self.prompts, "[Agent] User prompt")
|
||||
|
||||
def gen_conf(self):
|
||||
conf = {}
|
||||
if self.max_tokens > 0:
|
||||
conf["max_tokens"] = self.max_tokens
|
||||
if self.temperature > 0:
|
||||
conf["temperature"] = self.temperature
|
||||
if self.top_p > 0:
|
||||
conf["top_p"] = self.top_p
|
||||
if self.presence_penalty > 0:
|
||||
conf["presence_penalty"] = self.presence_penalty
|
||||
if self.frequency_penalty > 0:
|
||||
conf["frequency_penalty"] = self.frequency_penalty
|
||||
def get_attr(nm):
|
||||
try:
|
||||
return getattr(self, nm)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if int(self.max_tokens) > 0 and get_attr("maxTokensEnabled"):
|
||||
conf["max_tokens"] = int(self.max_tokens)
|
||||
if float(self.temperature) > 0 and get_attr("temperatureEnabled"):
|
||||
conf["temperature"] = float(self.temperature)
|
||||
if float(self.top_p) > 0 and get_attr("topPEnabled"):
|
||||
conf["top_p"] = float(self.top_p)
|
||||
if float(self.presence_penalty) > 0 and get_attr("presencePenaltyEnabled"):
|
||||
conf["presence_penalty"] = float(self.presence_penalty)
|
||||
if float(self.frequency_penalty) > 0 and get_attr("frequencyPenaltyEnabled"):
|
||||
conf["frequency_penalty"] = float(self.frequency_penalty)
|
||||
return conf
|
||||
|
||||
|
||||
@ -112,6 +120,12 @@ class LLM(ComponentBase):
|
||||
if not self.imgs:
|
||||
self.imgs = []
|
||||
self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"]
|
||||
if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
|
||||
self._param.llm_id, max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error
|
||||
)
|
||||
|
||||
|
||||
args = {}
|
||||
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
|
||||
@ -207,7 +221,8 @@ class LLM(ComponentBase):
|
||||
return
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self._stream_output, prompt, msg))
|
||||
return
|
||||
|
||||
@ -224,14 +239,22 @@ class LLM(ComponentBase):
|
||||
break
|
||||
|
||||
if error:
|
||||
self.set_output("_ERROR", error)
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
else:
|
||||
self.set_output("_ERROR", error)
|
||||
|
||||
def _stream_output(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
for ans in self._generate_streamly(msg):
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
yield ans
|
||||
answer += ans
|
||||
self.set_output("content", answer)
|
||||
@ -243,4 +266,4 @@ class LLM(ComponentBase):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
_, msg = self._prepare_prompt_variables()
|
||||
return f"I’m thinking and planning the next move, starting from the prompt:<br/>“{msg[-1]['content']}”<span class=\"collapse\"> (tap to see full text)</span>"
|
||||
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
|
||||
Reference in New Issue
Block a user