Feat: Support tool calling in Generate component (#7572)

### What problem does this PR solve?

Hello, our use case requires the LLM agent to invoke some tools, so I made a simple implementation here.

This PR does two things:

1. A simple plugin mechanism based on `pluginlib`:

This mechanism lives in the `plugin` directory. It will only load
plugins from `plugin/embedded_plugins` for now.

A sample plugin, `bad_calculator.py`, is placed in
`plugin/embedded_plugins/llm_tools`. It accepts two numbers `a` and `b`
and returns a deliberately wrong result, `a + b + 100`.

In the future, it could be extended to load plugins from external locations with
little code change.

Plugins are divided into different types. The only type supported in this PR is
`llm_tools`; plugins of this type must implement the `LLMToolPlugin` class
defined in `plugin/llm_tool_plugin.py` (a rough, hypothetical sketch of such a
plugin is included after the screenshots below).
More plugin types can be added in the future.

2. A tool selector in the `Generate` component:

The selector lets you pick one or more tools for the LLM to call:


![image](https://github.com/user-attachments/assets/74a21fdf-9333-4175-991b-43df6524c5dc)

And with the `bad_calculator` tool enabled, the `qwen-max` model produces this
result:


![image](https://github.com/user-attachments/assets/93aff9c4-8550-414a-90a2-1a15a5249d94)
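
For illustration, here is a minimal, hypothetical sketch of what an `llm_tools` plugin such as `bad_calculator.py` could look like. The real base class lives in `plugin/llm_tool_plugin.py` in this PR; the `get_metadata`/`invoke` method names, the metadata keys, and the `PluginLoader` discovery call mentioned in the comments are assumptions made for this sketch, not the actual interface:

```python
# Hypothetical sketch of plugin/embedded_plugins/llm_tools/bad_calculator.py.
# The real base class is defined in plugin/llm_tool_plugin.py in this PR;
# get_metadata/invoke, the metadata keys, and the loader call in the comments
# below are illustrative assumptions, not the actual interface.
import pluginlib


@pluginlib.Parent("llm_tools")
class LLMToolPlugin:
    """Assumed parent class that every llm_tools plugin derives from."""

    @classmethod
    def get_metadata(cls) -> dict:
        """Describe the tool (name, description, parameters) to the LLM."""
        raise NotImplementedError

    @pluginlib.abstractmethod
    def invoke(self, **kwargs) -> str:
        """Execute the tool and return its result as a string."""


class BadCalculatorPlugin(LLMToolPlugin):
    _alias_ = "bad_calculator"

    @classmethod
    def get_metadata(cls) -> dict:
        return {
            "name": "bad_calculator",
            "description": "Adds two numbers, deliberately returning a wrong sum.",
            "parameters": {
                "a": {"type": "number", "description": "The first number"},
                "b": {"type": "number", "description": "The second number"},
            },
        }

    def invoke(self, **kwargs) -> str:
        # Intentionally wrong, so the model's answer makes it obvious that the
        # tool was actually called instead of the LLM doing the math itself.
        return str(kwargs["a"] + kwargs["b"] + 100)


if __name__ == "__main__":
    # Discovery would roughly amount to something like:
    #   loader = pluginlib.PluginLoader(paths=["plugin/embedded_plugins"])
    #   tool_cls = loader.plugins["llm_tools"]["bad_calculator"]
    # Here the class is used directly to keep the sketch self-contained.
    print(BadCalculatorPlugin().invoke(a=1, b=2))  # prints "103"
```

Deliberately returning `a + b + 100` makes it easy to tell from the model's answer whether the tool was really invoked or the LLM silently did the arithmetic itself.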


### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Song Fuchang
2025-05-16 16:32:19 +08:00
committed by GitHub
parent cb26564d50
commit a1f06a4fdc
28 changed files with 625 additions and 61 deletions


@@ -21,6 +21,7 @@ import random
 import re
 import time
 from abc import ABC
+from typing import Any, Protocol
 import openai
 import requests
@@ -51,6 +52,10 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
 LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
+class ToolCallSession(Protocol):
+    def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
         timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
@@ -251,10 +256,8 @@ class Base(ABC):
                         if index not in final_tool_calls:
                             final_tool_calls[index] = tool_call
-                        final_tool_calls[index].function.arguments += tool_call.function.arguments
-                        if resp.choices[0].finish_reason != "stop":
-                            continue
+                        else:
+                            final_tool_calls[index].function.arguments += tool_call.function.arguments
                 else:
                     if not resp.choices:
                         continue
@@ -276,58 +279,57 @@
else:
total_tokens += tol
finish_reason = resp.choices[0].finish_reason
if finish_reason == "tool_calls" and final_tool_calls:
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
if name == "get_current_weather":
args = json.loads('{"location":"Shanghai"}')
else:
args = json.loads(tool_call.function.arguments)
except Exception:
continue
# args = json.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history.append(
{
"role": "assistant",
"refusal": "",
"content": "",
"audio": "",
"function_call": "",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": tool_call.function,
"type": "function",
finish_reason = resp.choices[0].finish_reason
if finish_reason == "tool_calls" and final_tool_calls:
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json.loads(tool_call.function.arguments)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
yield ans + "\n**ERROR**: " + str(e)
finish_completion = True
break
tool_response = self.toolcall_session.tool_call(name, args)
history.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
],
}
)
# if tool_response.choices[0].finish_reason == "length":
# if is_chinese(ans):
# ans += LENGTH_NOTIFICATION_CN
# else:
# ans += LENGTH_NOTIFICATION_EN
# return ans, total_tokens + self.total_token_count(tool_response)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
final_tool_calls = {}
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
continue
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, total_tokens + self.total_token_count(resp)
if finish_reason == "stop":
finish_completion = True
yield ans
break
yield ans
"type": "function",
},
],
}
)
# if tool_response.choices[0].finish_reason == "length":
# if is_chinese(ans):
# ans += LENGTH_NOTIFICATION_CN
# else:
# ans += LENGTH_NOTIFICATION_EN
# return ans, total_tokens + self.total_token_count(tool_response)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
final_tool_calls = {}
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
continue
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, total_tokens
if finish_reason == "stop":
finish_completion = True
yield ans
break
yield ans
continue
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -854,6 +856,14 @@ class ZhipuChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
+    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+        return super().chat_with_tools(system, history, gen_conf)
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
@@ -886,6 +896,14 @@ class ZhipuChat(Base):
         yield tk_count
+    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+        return super().chat_streamly_with_tools(system, history, gen_conf)
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):