Files
ragflow/agent/tools/base.py
Kevin Hu d9fe279dde Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve?

#9082 #6365

<u> **WARNING: it's not compatible with the older version of `Agent`
module, which means that `Agent` from older versions will no longer
work.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-07-30 19:41:09 +08:00

168 lines
5.4 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import time
from copy import deepcopy
from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
from api.utils import hash_str2int
from rag.llm.chat_model import ToolCallSession
from rag.prompts.prompts import kb_prompt
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
class _ToolParameterBase(TypedDict):
    """Keys that every tool parameter specification must provide.

    ``ToolParamBase.get_meta`` indexes ``type``, ``description`` and
    ``required`` directly, so they have to be present on every entry.
    """

    type: str
    description: str
    # Presumably a human-facing variant of ``description``; not read in this module.
    displayDescription: str
    required: bool


class ToolParameter(_ToolParameterBase, total=False):
    """Specification of a single tool parameter.

    In addition to the mandatory keys inherited from the base, two optional
    keys are recognised:

    * ``enum`` -- allowed values; only copied into the function schema when
      the key is present (``ToolParamBase.get_meta`` tests ``"enum" in p``).
    * ``default`` -- initial attribute value; read with ``p.get("default")``
      by ``ToolParamBase._init_attr_by_meta``, so it may be absent.

    The two-class layout marks these keys optional without ``NotRequired``,
    which needs Python 3.11.
    """

    enum: List[str]
    default: Any
class ToolMeta(TypedDict):
    """Static description of a tool: its identity plus its parameter specs.

    ``ToolParamBase.get_meta`` reads ``name``, ``description`` and
    ``parameters`` from this mapping when it builds the function schema.
    """

    # Function name advertised for the tool unless the param object overrides it.
    name: str
    # Presumably a UI label; not read in this module.
    displayName: str
    # Description advertised for the tool unless the param object overrides it.
    description: str
    # Presumably a UI-facing description; not read in this module.
    displayDescription: str
    # Parameter name -> specification of that parameter.
    parameters: dict[str, ToolParameter]
class LLMToolPluginCallSession(ToolCallSession):
    """Route tool calls requested by an LLM to the objects registered by name."""

    def __init__(self, tools_map: dict[str, object], callback: partial):
        """Store the tool registry and the progress callback.

        Args:
            tools_map: Mapping from tool name to the object that executes it.
            callback: Called as ``callback(name, arguments, message)`` right
                before a tool is run.
        """
        self.tools_map = tools_map
        self.callback = callback

    def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
        """Run the tool registered under ``name`` and return its response.

        Args:
            name: Name of the tool to run; must be a key of ``tools_map``.
            arguments: Keyword arguments for the tool.

        Returns:
            Whatever the underlying tool returns.

        Raises:
            AssertionError: If no tool is registered under ``name``.
        """
        assert name in self.tools_map, f"LLM tool {name} does not exist"
        self.callback(name, arguments, " running ...")

        tool = self.tools_map[name]
        if not isinstance(tool, MCPToolCallSession):
            # Regular tools take the arguments as keywords.
            return tool.invoke(**arguments)
        # MCP sessions are addressed by tool name; the trailing 60 is
        # presumably a timeout in seconds -- confirm against MCPToolCallSession.
        return tool.tool_call(name, arguments, 60)

    def get_tool_obj(self, name):
        """Return the object registered under ``name`` (``KeyError`` if absent)."""
        return self.tools_map[name]
class ToolParamBase(ComponentParamBase):
    """Parameter holder for a tool, driven by the tool's ``meta`` description."""

    def __init__(self):
        # NOTE: ``self.meta`` is read by the two helpers below but is never
        # assigned in this class, so it has to exist before this initializer
        # runs (presumably set by the subclass -- confirm against subclasses).
        super().__init__()
        self._init_inputs()
        self._init_attr_by_meta()

    def _init_inputs(self):
        """Populate ``self.inputs`` with a deep copy of every parameter spec."""
        self.inputs = {}
        self.inputs.update(
            (name, deepcopy(spec)) for name, spec in self.meta["parameters"].items()
        )

    def _init_attr_by_meta(self):
        """Give each declared parameter an attribute, keeping existing ones.

        Attributes that are not yet present are set to the spec's ``default``
        entry, or ``None`` when the spec has none.
        """
        for name, spec in self.meta["parameters"].items():
            if hasattr(self, name):
                continue
            setattr(self, name, spec.get("default"))

    def get_meta(self):
        """Build the function-style schema dictionary describing this tool.

        Returns:
            A dict of the form ``{"type": "function", "function": {...}}``
            carrying the tool name, its description and an object schema
            listing every parameter and which of them are required.
        """
        properties = {}
        for name, spec in self.meta["parameters"].items():
            entry = {
                "type": spec["type"],
                "description": spec["description"]
            }
            # ``enum`` is only forwarded when the spec declares it.
            if "enum" in spec:
                entry["enum"] = spec["enum"]
            properties[name] = entry

        # Instance attributes, when present, override the static meta values.
        desc = getattr(self, "description", self.meta["description"])
        function_name = getattr(self, "function_name", self.meta["name"])

        required = [name for name, spec in self.meta["parameters"].items() if spec["required"]]
        schema = {
            "type": "object",
            "properties": properties,
            "required": required
        }
        return {
            "type": "function",
            "function": {
                "name": function_name,
                "description": desc,
                "parameters": schema
            }
        }
class ToolBase(ComponentBase):
    """Base class for tools that run inside a canvas."""

    def __init__(self, canvas, id, param: ComponentParamBase):
        """Bind the tool to its canvas and validate its parameters.

        Args:
            canvas: The owning ``Canvas`` instance.
            id: Identifier of this tool.
            param: Parameter object; its ``check()`` method is called here.

        Raises:
            AssertionError: If ``canvas`` is not a ``Canvas``.
        """
        from agent.canvas import Canvas  # Local import to avoid cyclic dependency

        assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
        self._canvas, self._id, self._param = canvas, id, param
        self._param.check()

    def get_meta(self) -> dict[str, Any]:
        """Return the schema produced by the parameter object's ``get_meta``."""
        return self._param.get_meta()

    def invoke(self, **kwargs):
        """Run ``_invoke`` and record timing and error outputs.

        Any exception raised by ``_invoke`` is logged, stored under the
        ``_ERROR`` output and returned as its string form instead of being
        propagated.

        Returns:
            The result of ``_invoke``, or the error message on failure.
        """
        self.set_output("_created_time", time.perf_counter())
        try:
            result = self._invoke(**kwargs)
        except Exception as exc:
            self._param.outputs["_ERROR"] = {"value": str(exc)}
            logging.exception(exc)
            result = str(exc)

        self._param.debug_inputs = []
        elapsed = time.perf_counter() - self.output("_created_time")
        self.set_output("_elapsed_time", elapsed)
        return result

    def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
        """Turn raw results into reference chunks and publish them.

        Args:
            res_list: Raw result items.
            get_title: Callable returning the title of an item.
            get_url: Callable returning the URL of an item.
            get_content: Callable returning the text of an item.
            get_score: Optional callable returning an item's similarity;
                when omitted every chunk gets a similarity of ``1``.

        Items with no content (before or after clean-up) are skipped. The
        collected chunks are registered on the canvas and a formatted version
        is stored in the ``formalized_content`` output.
        """
        # Matches inline base64 PNG images written in Markdown link/image syntax.
        inline_png = r"!?\[[a-z]+\]\(data:image/png;base64,[ 0-9A-Za-z/_=+-]+\)"

        chunks, aggs = [], []
        for item in res_list:
            text = get_content(item)
            if not text:
                continue
            # Drop embedded images, then cap the text at 10000 characters.
            text = re.sub(inline_png, "", text)[:10000]
            if not text:
                continue

            # The same content-derived value serves as chunk id and doc id.
            ref_id = str(hash_str2int(text))
            title = get_title(item)
            url = get_url(item)
            if get_score:
                score = get_score(item)
            else:
                score = 1

            chunks.append({
                "chunk_id": ref_id,
                "content": text,
                "doc_id": ref_id,
                "docnm_kwd": title,
                "similarity": score,
                "url": url
            })
            aggs.append({
                "doc_name": title,
                "doc_id": ref_id,
                "count": 1,
                "url": url
            })

        # NOTE(review): the canvas method is spelled ``add_refernce`` here;
        # keep the call in sync with the Canvas API.
        self._canvas.add_refernce(chunks, aggs)
        formalized = "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True))
        self.set_output("formalized_content", formalized)