mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: more robust mcp tool call (#8631)
### What problem does this PR solve? More robust MCP tool call conn. ### Type of change - [x] Refactoring
This commit is contained in:
@ -65,32 +65,42 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
|
|
||||||
if self._mcp_server.server_type == MCPServerType.SSE:
|
if self._mcp_server.server_type == MCPServerType.SSE:
|
||||||
# SSE transport
|
# SSE transport
|
||||||
|
try:
|
||||||
async with sse_client(url, headers) as stream:
|
async with sse_client(url, headers) as stream:
|
||||||
async with ClientSession(*stream) as client_session:
|
async with ClientSession(*stream) as client_session:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
||||||
logging.info("client_session initialized successfully")
|
logging.info("client_session initialized successfully")
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
|
|
||||||
return
|
|
||||||
await self._process_mcp_tasks(client_session)
|
await self._process_mcp_tasks(client_session)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
|
||||||
|
logging.error(msg)
|
||||||
|
await self._process_mcp_tasks(None, msg)
|
||||||
|
except Exception:
|
||||||
|
msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
|
||||||
|
await self._process_mcp_tasks(None, msg)
|
||||||
|
|
||||||
elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
|
elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||||
# Streamable HTTP transport
|
# Streamable HTTP transport
|
||||||
|
try:
|
||||||
async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
|
async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
|
||||||
async with ClientSession(read_stream, write_stream) as client_session:
|
async with ClientSession(read_stream, write_stream) as client_session:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
||||||
logging.info("client_session initialized successfully")
|
logging.info("client_session initialized successfully")
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
|
|
||||||
return
|
|
||||||
await asyncio.wait_for(client_session.initialize(), timeout=5)
|
|
||||||
await self._process_mcp_tasks(client_session)
|
await self._process_mcp_tasks(client_session)
|
||||||
else:
|
except asyncio.TimeoutError:
|
||||||
raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")
|
msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
|
||||||
|
logging.error(msg)
|
||||||
|
await self._process_mcp_tasks(None, msg)
|
||||||
|
except Exception:
|
||||||
|
msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
|
||||||
|
await self._process_mcp_tasks(None, msg)
|
||||||
|
|
||||||
async def _process_mcp_tasks(self, client_session: ClientSession) -> None:
|
else:
|
||||||
|
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
|
||||||
|
|
||||||
|
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
|
||||||
while not self._close:
|
while not self._close:
|
||||||
try:
|
try:
|
||||||
mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
|
mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
|
||||||
@ -100,6 +110,12 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
|
logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
|
||||||
|
|
||||||
r: Any = None
|
r: Any = None
|
||||||
|
|
||||||
|
if not client_session or error_message:
|
||||||
|
r = ValueError(error_message)
|
||||||
|
await result_queue.put(r)
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if mcp_task == "list_tools":
|
if mcp_task == "list_tools":
|
||||||
r = await client_session.list_tools()
|
r = await client_session.list_tools()
|
||||||
@ -112,21 +128,22 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
|
|
||||||
await result_queue.put(r)
|
await result_queue.put(r)
|
||||||
|
|
||||||
async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any:
|
async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float | int = 8, **kwargs) -> Any:
|
||||||
results = asyncio.Queue()
|
results = asyncio.Queue()
|
||||||
await self._queue.put((task_type, kwargs, results))
|
await self._queue.put((task_type, kwargs, results))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
|
result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
|
|
||||||
|
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
raise result
|
raise result
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s")
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str:
|
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
|
||||||
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments)
|
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
|
||||||
|
|
||||||
if result.isError:
|
if result.isError:
|
||||||
return f"MCP server error: {result.content}"
|
return f"MCP server error: {result.content}"
|
||||||
@ -137,23 +154,27 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
else:
|
else:
|
||||||
return f"Unsupported content type {type(result.content)}"
|
return f"Unsupported content type {type(result.content)}"
|
||||||
|
|
||||||
async def _get_tools_from_mcp_server(self) -> list[Tool]:
|
async def _get_tools_from_mcp_server(self, timeout: float | int = 8) -> list[Tool]:
|
||||||
result: ListToolsResult = await self._call_mcp_server("list_tools")
|
try:
|
||||||
|
result: ListToolsResult = await self._call_mcp_server("list_tools", timeout=timeout)
|
||||||
return result.tools
|
return result.tools
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
def get_tools(self, timeout: float = 10) -> list[Tool]:
|
def get_tools(self, timeout: float | int = 10) -> list[Tool]:
|
||||||
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop)
|
future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(timeout=timeout), self._event_loop)
|
||||||
try:
|
try:
|
||||||
return future.result(timeout=timeout)
|
return future.result(timeout=timeout)
|
||||||
except FuturesTimeoutError:
|
except FuturesTimeoutError:
|
||||||
logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})")
|
msg = f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})"
|
||||||
return []
|
logging.error(msg)
|
||||||
|
raise RuntimeError(msg)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
|
logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
|
||||||
return []
|
raise
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str:
|
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
|
||||||
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
|
future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
|
||||||
try:
|
try:
|
||||||
return future.result(timeout=timeout)
|
return future.result(timeout=timeout)
|
||||||
@ -173,7 +194,7 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
self._thread_pool.shutdown(wait=True)
|
self._thread_pool.shutdown(wait=True)
|
||||||
self.__class__._ALL_INSTANCES.discard(self)
|
self.__class__._ALL_INSTANCES.discard(self)
|
||||||
|
|
||||||
def close_sync(self, timeout: float = 5) -> None:
|
def close_sync(self, timeout: float | int = 5) -> None:
|
||||||
if not self._event_loop.is_running():
|
if not self._event_loop.is_running():
|
||||||
logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
|
logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user