mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Scratch MCP tool calling support. (#8263)
### What problem does this PR solve? This is a cherry-pick from #7781 as requested. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@ -61,6 +61,9 @@ class ToolCallSession(Protocol):
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
tools: list[Any]
|
||||
toolcall_sessions: dict[str, ToolCallSession]
|
||||
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
||||
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
|
||||
@ -70,6 +73,8 @@ class Base(ABC):
|
||||
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
||||
self.max_rounds = kwargs.get("max_rounds", 5)
|
||||
self.is_tools = False
|
||||
self.tools = []
|
||||
self.toolcall_sessions = {}
|
||||
|
||||
def _get_delay(self):
|
||||
"""Calculate retry delay time"""
|
||||
@ -145,8 +150,10 @@ class Base(ABC):
|
||||
if not (toolcall_session and tools):
|
||||
return
|
||||
self.is_tools = True
|
||||
self.toolcall_session = toolcall_session
|
||||
self.tools = tools
|
||||
|
||||
for tool in tools:
|
||||
self.toolcall_sessions[tool["function"]["name"]] = toolcall_session
|
||||
self.tools.append(tool)
|
||||
|
||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
|
||||
gen_conf = self._clean_conf()
|
||||
@ -180,7 +187,7 @@ class Base(ABC):
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
tool_response = self.toolcall_sessions[name].tool_call(name, args)
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
|
||||
except Exception as e:
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
@ -286,7 +293,7 @@ class Base(ABC):
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
tool_response = self.toolcall_sessions[name].tool_call(name, args)
|
||||
history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
@ -585,7 +592,7 @@ class QWenChat(Base):
|
||||
tool_name = assistant_output.tool_calls[0]["function"]["name"]
|
||||
if tool_name:
|
||||
arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
|
||||
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments)
|
||||
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
|
||||
history.append(tool_info)
|
||||
|
||||
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
|
||||
@ -708,7 +715,7 @@ class QWenChat(Base):
|
||||
|
||||
tool_name = toolcall_message.tool_calls[0]["function"]["name"]
|
||||
history.append(toolcall_message)
|
||||
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments)
|
||||
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
|
||||
history.append(tool_info)
|
||||
tool_info = {"content": "", "role": "tool"}
|
||||
tool_name = ""
|
||||
|
||||
Reference in New Issue
Block a user