Perf: Enhance timeout handling. (#8826)

### What problem does this PR solve?


### Type of change

- [x] Performance Improvement
This commit is contained in:
Kevin Hu
2025-07-15 09:36:45 +08:00
committed by GitHub
parent ce140f1393
commit c642dbefca
10 changed files with 207 additions and 85 deletions

View File

@ -13,19 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Optional, Union, Callable, Coroutine, Type
from urllib.parse import quote, urlencode
from uuid import uuid1
import trio
from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
import requests
from flask import (
Response,
@ -558,3 +568,101 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
transformed_data[mapped_key] = value
return transformed_data
def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
results = {}
tool_call_sessions = []
try:
for mcp_server in mcp_servers:
server_key = mcp_server.id
cached_tools = mcp_server.variables.get("tools", {})
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
try:
tools = tool_call_session.get_tools(timeout)
except Exception:
tools = []
results[server_key] = []
for tool in tools:
tool_dict = tool.model_dump()
cached_tool = cached_tools.get(tool_dict["name"], {})
tool_dict["enabled"] = cached_tool.get("enabled", True)
results[server_key].append(tool_dict)
# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
return results, ""
except Exception as e:
return {}, str(e)
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(
seconds: float |int = None,
*,
exception: Optional[TimeoutException] = None,
on_timeout: Optional[OnTimeoutCallback] = None
):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
result_queue = queue.Queue(maxsize=1)
def target():
try:
result = func(*args, **kwargs)
result_queue.put(result)
except Exception as e:
result_queue.put(e)
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
try:
result = result_queue.get(timeout=seconds)
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")
@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None:
return await func(*args, **kwargs)
try:
with trio.fail_after(seconds):
return await func(*args, **kwargs)
except trio.TooSlowError:
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout
if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds")
if isinstance(exception, BaseException):
raise exception
if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds")
raise RuntimeError("Invalid exception type provided")
if asyncio.iscoroutinefunction(func):
return async_wrapper
return wrapper
return decorator

View File

@ -1,34 +0,0 @@
from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
results = {}
tool_call_sessions = []
try:
for mcp_server in mcp_servers:
server_key = mcp_server.id
cached_tools = mcp_server.variables.get("tools", {})
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
try:
tools = tool_call_session.get_tools(timeout)
except Exception:
tools = []
results[server_key] = []
for tool in tools:
tool_dict = tool.model_dump()
cached_tool = cached_tools.get(tool_dict["name"], {})
tool_dict["enabled"] = cached_tool.get("enabled", True)
results[server_key].append(tool_dict)
# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
return results, ""
except Exception as e:
return {}, str(e)