diff --git a/rag/utils/mcp_tool_call_conn.py b/rag/utils/mcp_tool_call_conn.py index 5aca0e754..5e449477f 100644 --- a/rag/utils/mcp_tool_call_conn.py +++ b/rag/utils/mcp_tool_call_conn.py @@ -65,32 +65,42 @@ class MCPToolCallSession(ToolCallSession): if self._mcp_server.server_type == MCPServerType.SSE: # SSE transport - async with sse_client(url, headers) as stream: - async with ClientSession(*stream) as client_session: - try: - await asyncio.wait_for(client_session.initialize(), timeout=5) - logging.info("client_session initialized successfully") - except asyncio.TimeoutError: - logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}") - return - await self._process_mcp_tasks(client_session) + try: + async with sse_client(url, headers) as stream: + async with ClientSession(*stream) as client_session: + try: + await asyncio.wait_for(client_session.initialize(), timeout=5) + logging.info("client_session initialized successfully") + await self._process_mcp_tasks(client_session) + except asyncio.TimeoutError: + msg = f"Timeout initializing client_session for server {self._mcp_server.id}" + logging.error(msg) + await self._process_mcp_tasks(None, msg) + except Exception: + msg = "Connection failed (possibly due to auth error). Please check authentication settings first" + await self._process_mcp_tasks(None, msg) elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP: # Streamable HTTP transport - async with streamablehttp_client(url, headers) as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as client_session: - try: - await asyncio.wait_for(client_session.initialize(), timeout=5) - logging.info("client_session initialized successfully") - except asyncio.TimeoutError: - logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}") - return - await asyncio.wait_for(client_session.initialize(), timeout=5) - await self._process_mcp_tasks(client_session) - else: - raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}") + try: + async with streamablehttp_client(url, headers) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as client_session: + try: + await asyncio.wait_for(client_session.initialize(), timeout=5) + logging.info("client_session initialized successfully") + await self._process_mcp_tasks(client_session) + except asyncio.TimeoutError: + msg = f"Timeout initializing client_session for server {self._mcp_server.id}" + logging.error(msg) + await self._process_mcp_tasks(None, msg) + except Exception: + msg = "Connection failed (possibly due to auth error). Please check authentication settings first" + await self._process_mcp_tasks(None, msg) - async def _process_mcp_tasks(self, client_session: ClientSession) -> None: + else: + await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}") + + async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None: while not self._close: try: mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1) @@ -100,6 +110,12 @@ class MCPToolCallSession(ToolCallSession): logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") r: Any = None + + if not client_session or error_message: + r = ValueError(error_message) + await result_queue.put(r) + continue + try: if mcp_task == "list_tools": r = await client_session.list_tools() @@ -112,21 +128,22 @@ class MCPToolCallSession(ToolCallSession): await result_queue.put(r) - async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any: + async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float | int = 8, **kwargs) -> Any: results = asyncio.Queue() await self._queue.put((task_type, kwargs, results)) + try: result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout) + if isinstance(result, Exception): + raise result + return result except asyncio.TimeoutError: raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s") + except Exception: + raise - if isinstance(result, Exception): - raise result - - return result - - async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str: - result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments) + async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: + result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout) if result.isError: return f"MCP server error: {result.content}" @@ -137,23 +154,27 @@ class MCPToolCallSession(ToolCallSession): else: return f"Unsupported content type {type(result.content)}" - async def _get_tools_from_mcp_server(self) -> list[Tool]: - result: ListToolsResult = await self._call_mcp_server("list_tools") - return result.tools + async def _get_tools_from_mcp_server(self, timeout: float | int = 8) -> list[Tool]: + try: + result: ListToolsResult = await self._call_mcp_server("list_tools", timeout=timeout) + return result.tools + except Exception: + raise - def get_tools(self, timeout: float = 10) -> list[Tool]: - future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop) + def get_tools(self, timeout: float | int = 10) -> list[Tool]: + future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(timeout=timeout), self._event_loop) try: return future.result(timeout=timeout) except FuturesTimeoutError: - logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})") - return [] + msg = f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})" + logging.error(msg) + raise RuntimeError(msg) except Exception: logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}") - return [] + raise @override - def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str: + def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop) try: return future.result(timeout=timeout) @@ -173,7 +194,7 @@ class MCPToolCallSession(ToolCallSession): self._thread_pool.shutdown(wait=True) self.__class__._ALL_INSTANCES.discard(self) - def close_sync(self, timeout: float = 5) -> None: + def close_sync(self, timeout: float | int = 5) -> None: if not self._event_loop.is_running(): logging.warning(f"Event loop already stopped for {self._mcp_server.id}") return