Refa: change mcp_client module to rag/utils/conn (#8578)

### What problem does this PR solve?

Change mcp_client module to rag/utils/conn.

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-07-01 09:29:19 +08:00
committed by GitHub
parent d620432e3b
commit 8801de2772
6 changed files with 27 additions and 14 deletions

View File

@ -200,7 +200,6 @@ COPY graphrag graphrag
COPY agentic_reasoning agentic_reasoning COPY agentic_reasoning agentic_reasoning
COPY pyproject.toml uv.lock ./ COPY pyproject.toml uv.lock ./
COPY mcp mcp COPY mcp mcp
COPY mcp_client mcp_client
COPY plugin plugin COPY plugin plugin
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template

View File

@ -33,7 +33,6 @@ ADD ./rag ./rag
ADD ./requirements.txt ./requirements.txt ADD ./requirements.txt ./requirements.txt
ADD ./agent ./agent ADD ./agent ./agent
ADD ./graphrag ./graphrag ADD ./graphrag ./graphrag
ADD ./mcp_client ./mcp_client
ADD ./plugin ./plugin ADD ./plugin ./plugin
RUN dnf install -y openmpi openmpi-devel python3-openmpi RUN dnf install -y openmpi openmpi-devel python3-openmpi

View File

@ -9,7 +9,7 @@ from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse from api.utils.web_utils import get_float, safe_json_parse
from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821 @manager.route("/list", methods=["POST"]) # noqa: F821

View File

@ -19,7 +19,6 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
from api.utils.log_utils import init_root_logger from api.utils.log_utils import init_root_logger
from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions
from plugin import GlobalPluginManager from plugin import GlobalPluginManager
init_root_logger("ragflow_server") init_root_logger("ragflow_server")
@ -44,6 +43,7 @@ from api.db.init_data import init_web_data
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from api.utils import show_configs from api.utils import show_configs
from rag.settings import print_rag_settings from rag.settings import print_rag_settings
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event() stop_event = threading.Event()

View File

@ -1,2 +0,0 @@
# ruff: noqa: F401
from .mcp_tool_call import MCPToolCallSession, mcp_tool_metadata_to_openai_tool, close_multiple_mcp_toolcall_sessions

View File

@ -1,8 +1,25 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio import asyncio
import logging import logging
import threading import threading
import weakref import weakref
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from string import Template from string import Template
from typing import Any, Literal from typing import Any, Literal
@ -101,7 +118,7 @@ class MCPToolCallSession(ToolCallSession):
try: try:
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout) result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s") raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
if isinstance(result, Exception): if isinstance(result, Exception):
raise result raise result
@ -128,8 +145,8 @@ class MCPToolCallSession(ToolCallSession):
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop) future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop)
try: try:
return future.result(timeout=timeout) return future.result(timeout=timeout)
except TimeoutError: except FuturesTimeoutError:
logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}") logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})")
return [] return []
except Exception: except Exception:
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}") logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
@ -140,9 +157,9 @@ class MCPToolCallSession(ToolCallSession):
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop) future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
try: try:
return future.result(timeout=timeout) return future.result(timeout=timeout)
except TimeoutError as te: except FuturesTimeoutError:
logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}") logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})")
return f"Timeout calling tool '{name}': {te}." return f"Timeout calling tool '{name}' (timeout={timeout})."
except Exception as e: except Exception as e:
logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}") logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
return f"Error calling tool '{name}': {e}." return f"Error calling tool '{name}': {e}."
@ -164,8 +181,8 @@ class MCPToolCallSession(ToolCallSession):
future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop) future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
try: try:
future.result(timeout=timeout) future.result(timeout=timeout)
except TimeoutError: except FuturesTimeoutError:
logging.error(f"Timeout while closing session for server {self._mcp_server.id}") logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})")
except Exception: except Exception:
logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}") logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")