Refactor: move mcp connection utilities to common (#11304)

### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-11-17 15:34:17 +08:00
committed by GitHub
parent 0569b50fed
commit bd4bc57009
10 changed files with 27 additions and 27 deletions

View File

@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \ from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM from agent.component.llm import LLMParam, LLM

View File

@ -21,9 +21,8 @@ from functools import partial
from typing import TypedDict, List, Any from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase from agent.component.base import ComponentParamBase, ComponentBase
from common.misc_utils import hash_str2int from common.misc_utils import hash_str2int
from rag.llm.chat_model import ToolCallSession
from rag.prompts.generator import kb_prompt from rag.prompts.generator import kb_prompt
from rag.utils.mcp_tool_call_conn import MCPToolCallSession from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer from timeit import default_timer as timer

View File

@ -25,7 +25,7 @@ from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_mcp_tools get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821 @manager.route("/list", methods=["POST"]) # noqa: F821

View File

@ -41,12 +41,12 @@ def list_agents(tenant_id):
return get_error_data_result("The agent doesn't exist.") return get_error_data_result("The agent doesn't exist.")
page_number = int(request.args.get("page", 1)) page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30)) items_per_page = int(request.args.get("page_size", 30))
orderby = request.args.get("orderby", "update_time") order_by = request.args.get("orderby", "update_time")
if request.args.get("desc") == "False" or request.args.get("desc") == "false": if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False desc = False
else: else:
desc = True desc = True
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title) canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
return get_result(data=canvas) return get_result(data=canvas)

View File

@ -41,7 +41,7 @@ from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data from api.db.init_data import init_web_data
from common.versions import get_ragflow_version from common.versions import get_ragflow_version
from common.config_utils import show_configs from common.config_utils import show_configs
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock from rag.utils.redis_conn import RedisDistributedLock
stop_event = threading.Event() stop_event = threading.Event()

View File

@ -37,7 +37,7 @@ from peewee import OperationalError
from common.constants import ActiveEnum from common.constants import ActiveEnum
from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.utils.json_encode import CustomJSONEncoder from api.utils.json_encode import CustomJSONEncoder
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from api.db.services.tenant_llm_service import LLMFactoriesService from api.db.services.tenant_llm_service import LLMFactoriesService
from common.connection_utils import timeout from common.connection_utils import timeout
from common.constants import RetCode from common.constants import RetCode

View File

@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC):
class CheckpointedConnectorWithPermSync(ABC): class CheckpointedConnectorWithPermSync(ABC):
"""Checkpointed connector interface (with permission sync)""" """Checkpoint connector interface (with permission sync)"""
@abstractmethod @abstractmethod
def load_from_checkpoint( def load_from_checkpoint(
@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod @abc.abstractmethod
def is_dynamic(self) -> bool: def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client """If dynamic, the credentials may change during usage ... meaning the client
needs to use the locking features of the credentials provider to operate needs to use the locking features of the credentials provider to operate
correctly. correctly.

View File

@ -21,7 +21,7 @@ import weakref
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError from concurrent.futures import TimeoutError as FuturesTimeoutError
from string import Template from string import Template
from typing import Any, Literal from typing import Any, Literal, Protocol
from typing_extensions import override from typing_extensions import override
@ -30,12 +30,15 @@ from mcp.client.session import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from rag.llm.chat_model import ToolCallSession
MCPTaskType = Literal["list_tools", "tool_call"] MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]] MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class MCPToolCallSession(ToolCallSession): class MCPToolCallSession(ToolCallSession):
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet() _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession):
await self._process_mcp_tasks(None, msg) await self._process_mcp_tasks(None, msg)
else: else:
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}") await self._process_mcp_tasks(None,
f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None: async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
while not self._close: while not self._close:
@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession):
raise raise
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout) result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
timeout=timeout)
if result.isError: if result.isError:
return f"MCP server error: {result.content}" return f"MCP server error: {result.content}"
@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) ->
except Exception: except Exception:
logging.exception("Exception during MCP session cleanup thread management") logging.exception("Exception during MCP session cleanup thread management")
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") logging.info(
f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
def shutdown_all_mcp_sessions(): def shutdown_all_mcp_sessions():
@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions():
logging.info("All MCPToolCallSession instances have been closed.") logging.info("All MCPToolCallSession instances have been closed.")
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]: def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
if isinstance(mcp_tool, dict): if isinstance(mcp_tool, dict):
return { return {
"type": "function", "type": "function",

View File

@ -22,7 +22,6 @@ import re
import time import time
from abc import ABC from abc import ABC
from copy import deepcopy from copy import deepcopy
from typing import Any, Protocol
from urllib.parse import urljoin from urllib.parse import urljoin
import json_repair import json_repair
@ -65,10 +64,6 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length." LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class Base(ABC): class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs): def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))

View File

@ -155,13 +155,13 @@ def qbullets_category(sections):
if re.match(pro, sec) and not not_bullet(sec): if re.match(pro, sec) and not not_bullet(sec):
hits[i] += 1 hits[i] += 1
break break
maxium = 0 maximum = 0
res = -1 res = -1
for i, h in enumerate(hits): for i, h in enumerate(hits):
if h <= maxium: if h <= maximum:
continue continue
res = i res = i
maxium = h maximum = h
return res, QUESTION_PATTERN[res] return res, QUESTION_PATTERN[res]
@ -222,13 +222,13 @@ def bullets_category(sections):
if re.match(p, sec) and not not_bullet(sec): if re.match(p, sec) and not not_bullet(sec):
hits[i] += 1 hits[i] += 1
break break
maxium = 0 maximum = 0
res = -1 res = -1
for i, h in enumerate(hits): for i, h in enumerate(hits):
if h <= maxium: if h <= maximum:
continue continue
res = i res = i
maxium = h maximum = h
return res return res