Feat: Scratch MCP tool calling support. (#8263)

### What problem does this PR solve?

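This is a cherry-pick of #7781, as requested. Among its changes, the chat model base class now keeps a tool-name-to-session mapping (`toolcall_sessions`) in place of the single `toolcall_session`, so each tool call is dispatched to the session that was bound together with that tool.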

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
Song Fuchang
2025-06-23 17:45:35 +08:00
committed by GitHub
parent e9c6891e24
commit fd7ac17605
14 changed files with 445 additions and 7 deletions

View File

@ -61,6 +61,9 @@ class ToolCallSession(Protocol):
class Base(ABC):
tools: list[Any]
toolcall_sessions: dict[str, ToolCallSession]
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
@ -70,6 +73,8 @@ class Base(ABC):
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
def _get_delay(self):
"""Calculate retry delay time"""
@ -145,8 +150,10 @@ class Base(ABC):
if not (toolcall_session and tools):
return
self.is_tools = True
self.toolcall_session = toolcall_session
self.tools = tools
for tool in tools:
self.toolcall_sessions[tool["function"]["name"]] = toolcall_session
self.tools.append(tool)
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
gen_conf = self._clean_conf()
@ -180,7 +187,7 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
except Exception as e:
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
@ -286,7 +293,7 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
history.append(
{
"role": "assistant",
@ -585,7 +592,7 @@ class QWenChat(Base):
tool_name = assistant_output.tool_calls[0]["function"]["name"]
if tool_name:
arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments)
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=arguments)
history.append(tool_info)
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
@ -708,7 +715,7 @@ class QWenChat(Base):
tool_name = toolcall_message.tool_calls[0]["function"]["name"]
history.append(toolcall_message)
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments)
tool_info["content"] = self.toolcall_sessions[tool_name].tool_call(name=tool_name, arguments=tool_arguments)
history.append(tool_info)
tool_info = {"content": "", "role": "tool"}
tool_name = ""