diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index a27504139..906a9eca3 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService from common.connection_utils import timeout from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \ citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in -from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool +from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool from agent.component.llm import LLMParam, LLM diff --git a/agent/tools/base.py b/agent/tools/base.py index a3d569694..791242d59 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -21,9 +21,8 @@ from functools import partial from typing import TypedDict, List, Any from agent.component.base import ComponentParamBase, ComponentBase from common.misc_utils import hash_str2int -from rag.llm.chat_model import ToolCallSession from rag.prompts.generator import kb_prompt -from rag.utils.mcp_tool_call_conn import MCPToolCallSession +from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession from timeit import default_timer as timer diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 66d447491..a8ac2aef1 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -25,7 +25,7 @@ from common.misc_utils import get_uuid from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ get_mcp_tools from api.utils.web_utils import get_float, safe_json_parse -from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions @manager.route("/list", methods=["POST"]) # noqa: F821 diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py index 208b7a1be..14ea97fb6 100644 --- a/api/apps/sdk/agents.py +++ b/api/apps/sdk/agents.py @@ -41,12 +41,12 @@ def list_agents(tenant_id): return get_error_data_result("The agent doesn't exist.") page_number = int(request.args.get("page", 1)) items_per_page = int(request.args.get("page_size", 30)) - orderby = request.args.get("orderby", "update_time") + order_by = request.args.get("orderby", "update_time") if request.args.get("desc") == "False" or request.args.get("desc") == "false": desc = False else: desc = True - canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title) + canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title) return get_result(data=canvas) diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 868e054ae..c340255e7 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -41,7 +41,7 @@ from api.db.db_models import init_database_tables as init_web_db from api.db.init_data import init_web_data from common.versions import get_ragflow_version from common.config_utils import show_configs -from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions +from common.mcp_tool_call_conn import shutdown_all_mcp_sessions from rag.utils.redis_conn import RedisDistributedLock stop_event = threading.Event() diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 4cace9eca..1bd3f3e3c 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -37,7 +37,7 @@ from peewee import OperationalError from common.constants import ActiveEnum from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder -from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from api.db.services.tenant_llm_service import LLMFactoriesService from common.connection_utils import timeout from common.constants import RetCode diff --git a/common/data_source/interfaces.py b/common/data_source/interfaces.py index 9c5f00141..5e5d3aa2e 100644 --- a/common/data_source/interfaces.py +++ b/common/data_source/interfaces.py @@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC): class CheckpointedConnectorWithPermSync(ABC): - """Checkpointed connector interface (with permission sync)""" + """Checkpoint connector interface (with permission sync)""" @abstractmethod def load_from_checkpoint( @@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]): @abc.abstractmethod def is_dynamic(self) -> bool: - """If dynamic, the credentials may change during usage ... maening the client + """If dynamic, the credentials may change during usage ... meaning the client needs to use the locking features of the credentials provider to operate correctly. diff --git a/rag/utils/mcp_tool_call_conn.py b/common/mcp_tool_call_conn.py similarity index 94% rename from rag/utils/mcp_tool_call_conn.py rename to common/mcp_tool_call_conn.py index 2093f7bc8..b19f063e1 100644 --- a/rag/utils/mcp_tool_call_conn.py +++ b/common/mcp_tool_call_conn.py @@ -21,7 +21,7 @@ import weakref from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError from string import Template -from typing import Any, Literal +from typing import Any, Literal, Protocol from typing_extensions import override @@ -30,12 +30,15 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool -from rag.llm.chat_model import ToolCallSession MCPTaskType = Literal["list_tools", "tool_call"] MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]] +class ToolCallSession(Protocol): + def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ... + + class MCPToolCallSession(ToolCallSession): _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet() @@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession): await self._process_mcp_tasks(None, msg) else: - await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}") + await self._process_mcp_tasks(None, + f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}") async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None: while not self._close: @@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession): raise async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: - result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout) + result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, + timeout=timeout) if result.isError: return f"MCP server error: {result.content}" @@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> except Exception: logging.exception("Exception during MCP session cleanup thread management") - logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") + logging.info( + f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") def shutdown_all_mcp_sessions(): @@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions(): logging.info("All MCPToolCallSession instances have been closed.") -def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]: +def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]: if isinstance(mcp_tool, dict): return { "type": "function", diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 17ddbc138..c9e3b29f7 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -22,7 +22,6 @@ import re import time from abc import ABC from copy import deepcopy -from typing import Any, Protocol from urllib.parse import urljoin import json_repair @@ -65,10 +64,6 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小 LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length." -class ToolCallSession(Protocol): - def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ... - - class Base(ABC): def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 80acf1d8f..de7c2ce60 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -155,13 +155,13 @@ def qbullets_category(sections): if re.match(pro, sec) and not not_bullet(sec): hits[i] += 1 break - maxium = 0 + maximum = 0 res = -1 for i, h in enumerate(hits): - if h <= maxium: + if h <= maximum: continue res = i - maxium = h + maximum = h return res, QUESTION_PATTERN[res] @@ -222,13 +222,13 @@ def bullets_category(sections): if re.match(p, sec) and not not_bullet(sec): hits[i] += 1 break - maxium = 0 + maximum = 0 res = -1 for i, h in enumerate(hits): - if h <= maxium: + if h <= maximum: continue res = i - maxium = h + maximum = h return res