mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Refactor: move mcp connection utilities to common (#11304)
### What problem does this PR solve? As title ### Type of change - [x] Refactoring --------- Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService
|
|||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
||||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||||
from agent.component.llm import LLMParam, LLM
|
from agent.component.llm import LLMParam, LLM
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -21,9 +21,8 @@ from functools import partial
|
|||||||
from typing import TypedDict, List, Any
|
from typing import TypedDict, List, Any
|
||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
from common.misc_utils import hash_str2int
|
from common.misc_utils import hash_str2int
|
||||||
from rag.llm.chat_model import ToolCallSession
|
|
||||||
from rag.prompts.generator import kb_prompt
|
from rag.prompts.generator import kb_prompt
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from common.misc_utils import get_uuid
|
|||||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||||
get_mcp_tools
|
get_mcp_tools
|
||||||
from api.utils.web_utils import get_float, safe_json_parse
|
from api.utils.web_utils import get_float, safe_json_parse
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||||
|
|||||||
@ -41,12 +41,12 @@ def list_agents(tenant_id):
|
|||||||
return get_error_data_result("The agent doesn't exist.")
|
return get_error_data_result("The agent doesn't exist.")
|
||||||
page_number = int(request.args.get("page", 1))
|
page_number = int(request.args.get("page", 1))
|
||||||
items_per_page = int(request.args.get("page_size", 30))
|
items_per_page = int(request.args.get("page_size", 30))
|
||||||
orderby = request.args.get("orderby", "update_time")
|
order_by = request.args.get("orderby", "update_time")
|
||||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||||
desc = False
|
desc = False
|
||||||
else:
|
else:
|
||||||
desc = True
|
desc = True
|
||||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
|
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
|
||||||
return get_result(data=canvas)
|
return get_result(data=canvas)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -41,7 +41,7 @@ from api.db.db_models import init_database_tables as init_web_db
|
|||||||
from api.db.init_data import init_web_data
|
from api.db.init_data import init_web_data
|
||||||
from common.versions import get_ragflow_version
|
from common.versions import get_ragflow_version
|
||||||
from common.config_utils import show_configs
|
from common.config_utils import show_configs
|
||||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||||
from rag.utils.redis_conn import RedisDistributedLock
|
from rag.utils.redis_conn import RedisDistributedLock
|
||||||
|
|
||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
|
|||||||
@ -37,7 +37,7 @@ from peewee import OperationalError
|
|||||||
from common.constants import ActiveEnum
|
from common.constants import ActiveEnum
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
from api.utils.json_encode import CustomJSONEncoder
|
from api.utils.json_encode import CustomJSONEncoder
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||||
from api.db.services.tenant_llm_service import LLMFactoriesService
|
from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
from common.constants import RetCode
|
from common.constants import RetCode
|
||||||
|
|||||||
@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class CheckpointedConnectorWithPermSync(ABC):
|
class CheckpointedConnectorWithPermSync(ABC):
|
||||||
"""Checkpointed connector interface (with permission sync)"""
|
"""Checkpoint connector interface (with permission sync)"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_from_checkpoint(
|
def load_from_checkpoint(
|
||||||
@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def is_dynamic(self) -> bool:
|
def is_dynamic(self) -> bool:
|
||||||
"""If dynamic, the credentials may change during usage ... maening the client
|
"""If dynamic, the credentials may change during usage ... meaning the client
|
||||||
needs to use the locking features of the credentials provider to operate
|
needs to use the locking features of the credentials provider to operate
|
||||||
correctly.
|
correctly.
|
||||||
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ import weakref
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||||
from string import Template
|
from string import Template
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Protocol
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -30,12 +30,15 @@ from mcp.client.session import ClientSession
|
|||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
||||||
from rag.llm.chat_model import ToolCallSession
|
|
||||||
|
|
||||||
MCPTaskType = Literal["list_tools", "tool_call"]
|
MCPTaskType = Literal["list_tools", "tool_call"]
|
||||||
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
|
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallSession(Protocol):
|
||||||
|
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
class MCPToolCallSession(ToolCallSession):
|
class MCPToolCallSession(ToolCallSession):
|
||||||
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
|
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
|
||||||
|
|
||||||
@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
await self._process_mcp_tasks(None, msg)
|
await self._process_mcp_tasks(None, msg)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
|
await self._process_mcp_tasks(None,
|
||||||
|
f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
|
||||||
|
|
||||||
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
|
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
|
||||||
while not self._close:
|
while not self._close:
|
||||||
@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
|
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
|
||||||
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
|
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
|
||||||
|
timeout=timeout)
|
||||||
|
|
||||||
if result.isError:
|
if result.isError:
|
||||||
return f"MCP server error: {result.content}"
|
return f"MCP server error: {result.content}"
|
||||||
@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) ->
|
|||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Exception during MCP session cleanup thread management")
|
logging.exception("Exception during MCP session cleanup thread management")
|
||||||
|
|
||||||
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
|
logging.info(
|
||||||
|
f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
|
||||||
|
|
||||||
|
|
||||||
def shutdown_all_mcp_sessions():
|
def shutdown_all_mcp_sessions():
|
||||||
@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions():
|
|||||||
logging.info("All MCPToolCallSession instances have been closed.")
|
logging.info("All MCPToolCallSession instances have been closed.")
|
||||||
|
|
||||||
|
|
||||||
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
|
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
|
||||||
if isinstance(mcp_tool, dict):
|
if isinstance(mcp_tool, dict):
|
||||||
return {
|
return {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -22,7 +22,6 @@ import re
|
|||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Protocol
|
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
@ -65,10 +64,6 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
|
|||||||
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
|
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
|
||||||
|
|
||||||
|
|
||||||
class ToolCallSession(Protocol):
|
|
||||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
|
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
def __init__(self, key, model_name, base_url, **kwargs):
|
def __init__(self, key, model_name, base_url, **kwargs):
|
||||||
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
||||||
|
|||||||
@ -155,13 +155,13 @@ def qbullets_category(sections):
|
|||||||
if re.match(pro, sec) and not not_bullet(sec):
|
if re.match(pro, sec) and not not_bullet(sec):
|
||||||
hits[i] += 1
|
hits[i] += 1
|
||||||
break
|
break
|
||||||
maxium = 0
|
maximum = 0
|
||||||
res = -1
|
res = -1
|
||||||
for i, h in enumerate(hits):
|
for i, h in enumerate(hits):
|
||||||
if h <= maxium:
|
if h <= maximum:
|
||||||
continue
|
continue
|
||||||
res = i
|
res = i
|
||||||
maxium = h
|
maximum = h
|
||||||
return res, QUESTION_PATTERN[res]
|
return res, QUESTION_PATTERN[res]
|
||||||
|
|
||||||
|
|
||||||
@ -222,13 +222,13 @@ def bullets_category(sections):
|
|||||||
if re.match(p, sec) and not not_bullet(sec):
|
if re.match(p, sec) and not not_bullet(sec):
|
||||||
hits[i] += 1
|
hits[i] += 1
|
||||||
break
|
break
|
||||||
maxium = 0
|
maximum = 0
|
||||||
res = -1
|
res = -1
|
||||||
for i, h in enumerate(hits):
|
for i, h in enumerate(hits):
|
||||||
if h <= maxium:
|
if h <= maximum:
|
||||||
continue
|
continue
|
||||||
res = i
|
res = i
|
||||||
maxium = h
|
maximum = h
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user