From 8801de2772c6530c10cd02028838dd47111fe803 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Tue, 1 Jul 2025 09:29:19 +0800 Subject: [PATCH] Refa: change mcp_client module to rag/utils/conn (#8578) ### What problem does this PR solve? Change mcp_client module to rag/utils/conn. ### Type of change - [x] Refactoring --- Dockerfile | 1 - Dockerfile.scratch.oc9 | 1 - api/apps/mcp_server_app.py | 2 +- api/ragflow_server.py | 2 +- mcp_client/__init__.py | 2 -- .../utils/mcp_tool_call_conn.py | 33 ++++++++++++++----- 6 files changed, 27 insertions(+), 14 deletions(-) delete mode 100644 mcp_client/__init__.py rename mcp_client/mcp_tool_call.py => rag/utils/mcp_tool_call_conn.py (88%) diff --git a/Dockerfile b/Dockerfile index 0f0727b63..67fd26456 100644 --- a/Dockerfile +++ b/Dockerfile @@ -200,7 +200,6 @@ COPY graphrag graphrag COPY agentic_reasoning agentic_reasoning COPY pyproject.toml uv.lock ./ COPY mcp mcp -COPY mcp_client mcp_client COPY plugin plugin COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template diff --git a/Dockerfile.scratch.oc9 b/Dockerfile.scratch.oc9 index 2403eae16..64424735e 100644 --- a/Dockerfile.scratch.oc9 +++ b/Dockerfile.scratch.oc9 @@ -33,7 +33,6 @@ ADD ./rag ./rag ADD ./requirements.txt ./requirements.txt ADD ./agent ./agent ADD ./graphrag ./graphrag -ADD ./mcp_client ./mcp_client ADD ./plugin ./plugin RUN dnf install -y openmpi openmpi-devel python3-openmpi diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 34bd9dfb2..fb81d51e0 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -9,7 +9,7 @@ from api.settings import RetCode from api.utils import get_uuid from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.web_utils import get_float, safe_json_parse -from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions @manager.route("/list", methods=["POST"]) # noqa: F821 diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 288ae7fd3..6d6c72ea1 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -19,7 +19,6 @@ # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code from api.utils.log_utils import init_root_logger -from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions from plugin import GlobalPluginManager init_root_logger("ragflow_server") @@ -44,6 +43,7 @@ from api.db.init_data import init_web_data from api.versions import get_ragflow_version from api.utils import show_configs from rag.settings import print_rag_settings +from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions from rag.utils.redis_conn import RedisDistributedLock stop_event = threading.Event() diff --git a/mcp_client/__init__.py b/mcp_client/__init__.py deleted file mode 100644 index a07125853..000000000 --- a/mcp_client/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# ruff: noqa: F401 -from .mcp_tool_call import MCPToolCallSession, mcp_tool_metadata_to_openai_tool, close_multiple_mcp_toolcall_sessions diff --git a/mcp_client/mcp_tool_call.py b/rag/utils/mcp_tool_call_conn.py similarity index 88% rename from mcp_client/mcp_tool_call.py rename to rag/utils/mcp_tool_call_conn.py index 22fe5d20b..5aca0e754 100644 --- a/mcp_client/mcp_tool_call.py +++ b/rag/utils/mcp_tool_call_conn.py @@ -1,8 +1,25 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import asyncio import logging import threading import weakref from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError from string import Template from typing import Any, Literal @@ -101,7 +118,7 @@ class MCPToolCallSession(ToolCallSession): try: result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout) except asyncio.TimeoutError: - raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s") + raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s") if isinstance(result, Exception): raise result @@ -128,8 +145,8 @@ class MCPToolCallSession(ToolCallSession): future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop) try: return future.result(timeout=timeout) - except TimeoutError: - logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}") + except FuturesTimeoutError: + logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})") return [] except Exception: logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}") @@ -140,9 +157,9 @@ class MCPToolCallSession(ToolCallSession): future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop) try: return future.result(timeout=timeout) - except TimeoutError as te: - logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}") - return f"Timeout calling tool '{name}': {te}." + except FuturesTimeoutError: + logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})") + return f"Timeout calling tool '{name}' (timeout={timeout})." except Exception as e: logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}") return f"Error calling tool '{name}': {e}." @@ -164,8 +181,8 @@ class MCPToolCallSession(ToolCallSession): future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop) try: future.result(timeout=timeout) - except TimeoutError: - logging.error(f"Timeout while closing session for server {self._mcp_server.id}") + except FuturesTimeoutError: + logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})") except Exception: logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")