Feat: Support tool calling in Generate component (#7572)
### What problem does this PR solve?

Hello, our use case requires the LLM agent to invoke some tools, so I made a simple implementation here. This PR does two things:

1. A simple plugin mechanism based on `pluginlib`. It lives in the `plugin` directory and, for now, only loads plugins from `plugin/embedded_plugins`. A sample plugin, `bad_calculator.py`, is placed in `plugin/embedded_plugins/llm_tools`; it accepts two numbers `a` and `b`, then gives the deliberately wrong result `a + b + 100`. In the future, plugins can be loaded from external locations with little code change. Plugins are divided into types; the only type supported in this PR is `llm_tools`, and such plugins must implement the `LLMToolPlugin` class in `plugin/llm_tool_plugin.py`. More plugin types can be added later (see the plugin sketch below).
2. A tool selector in the `Generate` component, used to pick one or more tools for the LLM. With the `bad_calculator` tool selected, the `qwen-max` model returns the tool's intentionally wrong result, confirming the call went through the plugin.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
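As a rough illustration of the plugin mechanism described above, here is a minimal sketch of an `llm_tools` plugin in the spirit of `bad_calculator.py`, assuming a `pluginlib` parent class. The `invoke` method name, the `_alias_`, and the loader snippet are illustrative assumptions, not the interface this PR actually defines in `plugin/llm_tool_plugin.py`:

```python
# Hedged sketch only: method and attribute names are assumptions for
# illustration; the real interface lives in plugin/llm_tool_plugin.py.
import pluginlib


@pluginlib.Parent("llm_tools")
class LLMToolPlugin:
    """Parent type that every llm_tools plugin must implement."""

    @pluginlib.abstractmethod
    def invoke(self, **kwargs) -> str:
        """Run the tool and return its result as a string."""


class BadCalculatorPlugin(LLMToolPlugin):
    """Adds two numbers, deliberately wrong -- mirrors the sample plugin."""

    _alias_ = "bad_calculator"

    def invoke(self, **kwargs) -> str:
        return str(kwargs["a"] + kwargs["b"] + 100)


# Loading side, roughly how pluginlib discovers the embedded plugins:
loader = pluginlib.PluginLoader(paths=["plugin/embedded_plugins"])
calculator = loader.plugins["llm_tools"]["bad_calculator"]()
print(calculator.invoke(a=1, b=2))  # "103"
```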
```diff
@@ -21,6 +21,7 @@ import random
 import re
 import time
 from abc import ABC
+from typing import Any, Protocol
 
 import openai
 import requests
```
```diff
@@ -51,6 +52,10 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
 LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
 
 
+class ToolCallSession(Protocol):
+    def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
+
+
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
         timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
```
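The new `ToolCallSession` is a `typing.Protocol`, so implementations satisfy it structurally; no subclassing is required. A minimal sketch (`EchoSession` is a made-up example, not part of this PR):

```python
from typing import Any, Protocol


class ToolCallSession(Protocol):
    def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...


class EchoSession:
    """Satisfies ToolCallSession by shape alone -- no inheritance needed."""

    def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
        return f"{name} called with {arguments}"


session: ToolCallSession = EchoSession()  # accepted by static type checkers
print(session.tool_call("bad_calculator", {"a": 1, "b": 2}))
```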
```diff
@@ -251,10 +256,8 @@ class Base(ABC):
 
                         if index not in final_tool_calls:
                             final_tool_calls[index] = tool_call
-
-                        final_tool_calls[index].function.arguments += tool_call.function.arguments
-                        if resp.choices[0].finish_reason != "stop":
-                            continue
+                        else:
+                            final_tool_calls[index].function.arguments += tool_call.function.arguments
                 else:
                     if not resp.choices:
                         continue
```
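For context on the restructured branch above: with `stream=True`, OpenAI-compatible APIs deliver a tool call's arguments as string fragments spread across chunks, keyed by `index`, and only the first fragment carries the call's `id` and function name. A self-contained sketch of that accumulation; the `_Function`/`_ToolCallDelta` stand-ins and the `deltas` data are illustrative:

```python
from dataclasses import dataclass, field


@dataclass
class _Function:
    name: str | None = None
    arguments: str = ""


@dataclass
class _ToolCallDelta:
    index: int
    id: str | None = None
    function: _Function = field(default_factory=_Function)


# Two chunks of one streamed call: only the first carries id and name,
# later ones only append argument text.
deltas = [
    _ToolCallDelta(index=0, id="call_1", function=_Function("bad_calculator", '{"a')),
    _ToolCallDelta(index=0, function=_Function(arguments='": 1, "b": 2}')),
]

final_tool_calls = {}
for tool_call in deltas:
    if tool_call.index not in final_tool_calls:
        final_tool_calls[tool_call.index] = tool_call  # first fragment: keep id/name
    else:
        # later fragments: concatenate the partial JSON argument string
        final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments

print(final_tool_calls[0].function.arguments)  # {"a": 1, "b": 2}
```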
```diff
@@ -276,58 +279,57 @@ class Base(ABC):
                     else:
                         total_tokens += tol
 
-                    finish_reason = resp.choices[0].finish_reason
-                    if finish_reason == "tool_calls" and final_tool_calls:
-                        for tool_call in final_tool_calls.values():
-                            name = tool_call.function.name
-                            try:
-                                if name == "get_current_weather":
-                                    args = json.loads('{"location":"Shanghai"}')
-                                else:
-                                    args = json.loads(tool_call.function.arguments)
-                            except Exception:
-                                continue
-                            # args = json.loads(tool_call.function.arguments)
-                            tool_response = self.toolcall_session.tool_call(name, args)
-                            history.append(
-                                {
-                                    "role": "assistant",
-                                    "refusal": "",
-                                    "content": "",
-                                    "audio": "",
-                                    "function_call": "",
-                                    "tool_calls": [
-                                        {
-                                            "index": tool_call.index,
-                                            "id": tool_call.id,
-                                            "function": tool_call.function,
-                                            "type": "function",
-                                        },
-                                    ],
-                                }
-                            )
-                            # if tool_response.choices[0].finish_reason == "length":
-                            #     if is_chinese(ans):
-                            #         ans += LENGTH_NOTIFICATION_CN
-                            #     else:
-                            #         ans += LENGTH_NOTIFICATION_EN
-                            #     return ans, total_tokens + self.total_token_count(tool_response)
-                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
-                        final_tool_calls = {}
-                        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
-                        continue
-                    if finish_reason == "length":
-                        if is_chinese(ans):
-                            ans += LENGTH_NOTIFICATION_CN
-                        else:
-                            ans += LENGTH_NOTIFICATION_EN
-                        return ans, total_tokens
-                    if finish_reason == "stop":
-                        finish_completion = True
-                        yield ans
-                        break
-                    yield ans
-                    continue
+                    finish_reason = resp.choices[0].finish_reason
+                    if finish_reason == "tool_calls" and final_tool_calls:
+                        for tool_call in final_tool_calls.values():
+                            name = tool_call.function.name
+                            try:
+                                args = json.loads(tool_call.function.arguments)
+                            except Exception as e:
+                                logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+                                yield ans + "\n**ERROR**: " + str(e)
+                                finish_completion = True
+                                break
+
+                            tool_response = self.toolcall_session.tool_call(name, args)
+                            history.append(
+                                {
+                                    "role": "assistant",
+                                    "tool_calls": [
+                                        {
+                                            "index": tool_call.index,
+                                            "id": tool_call.id,
+                                            "function": {
+                                                "name": tool_call.function.name,
+                                                "arguments": tool_call.function.arguments,
+                                            },
+                                            "type": "function",
+                                        },
+                                    ],
+                                }
+                            )
+                            # if tool_response.choices[0].finish_reason == "length":
+                            #     if is_chinese(ans):
+                            #         ans += LENGTH_NOTIFICATION_CN
+                            #     else:
+                            #         ans += LENGTH_NOTIFICATION_EN
+                            #     return ans, total_tokens + self.total_token_count(tool_response)
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        final_tool_calls = {}
+                        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                        continue
+                    if finish_reason == "length":
+                        if is_chinese(ans):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                        return ans, total_tokens + self.total_token_count(resp)
+                    if finish_reason == "stop":
+                        finish_completion = True
+                        yield ans
+                        break
+                    yield ans
+                    continue
 
             except openai.APIError as e:
                 yield ans + "\n**ERROR**: " + str(e)
```
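The rewritten block follows the standard OpenAI tool-calling message protocol: for each completed call it appends an assistant message that echoes the call, then a `tool` message carrying the result, and re-issues the completion so the model can answer from the tool output. A minimal sketch of the history after one `bad_calculator` round (ids and values illustrative):

```python
# Illustrative history for one tool round; ids and contents are made up.
history = [
    {"role": "user", "content": "What is 1 + 2?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "index": 0,
                "id": "call_1",
                "function": {"name": "bad_calculator", "arguments": '{"a": 1, "b": 2}'},
                "type": "function",
            },
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "103"},
]
# The loop then calls chat.completions.create(..., messages=history, tools=tools)
# again so the model can produce a final answer from the tool result.
```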
```diff
@@ -854,6 +856,14 @@ class ZhipuChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+
+        return super().chat_with_tools(system, history, gen_conf)
+
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
```
```diff
@@ -886,6 +896,14 @@ class ZhipuChat(Base):
 
         yield tk_count
 
+    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+
+        return super().chat_streamly_with_tools(system, history, gen_conf)
+
 
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
```
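The two `*_with_tools` overrides strip `presence_penalty` and `frequency_penalty` from `gen_conf` before delegating to the shared implementations in `Base`, presumably because ZhipuAI's API rejects those parameters.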