diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index b3e2df6f5..53c2ed9ad 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -27,7 +27,7 @@ from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.mcp_server_service import MCPServerService -from api.utils.api_utils import timeout +from common.connection_utils import timeout from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \ citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool diff --git a/agent/component/base.py b/agent/component/base.py index 73f11ba95..db9a8f29d 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -25,7 +25,7 @@ from typing import Any, List, Union import pandas as pd import trio from agent import settings -from api.utils.api_utils import timeout +from common.connection_utils import timeout _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" diff --git a/agent/component/categorize.py b/agent/component/categorize.py index af2666fcb..1c611daf7 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -21,7 +21,7 @@ from abc import ABC from api.db import LLMType from api.db.services.llm_service import LLMBundle from agent.component.llm import LLMParam, LLM -from api.utils.api_utils import timeout +from common.connection_utils import timeout from rag.llm.chat_model import ERROR_PREFIX diff --git a/agent/component/invoke.py b/agent/component/invoke.py index d31c7ed25..00a39b905 100644 --- a/agent/component/invoke.py +++ b/agent/component/invoke.py @@ -23,7 +23,7 @@ from abc import ABC import requests from agent.component.base import ComponentBase, ComponentParamBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout from deepdoc.parser import HtmlParser diff --git a/agent/component/llm.py b/agent/component/llm.py index 61e52b6fa..123124765 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -25,7 +25,7 @@ from api.db import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService from agent.component.base import ComponentBase, ComponentParamBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt, structured_output_prompt diff --git a/agent/component/message.py b/agent/component/message.py index 3569065e5..a91c2522e 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -23,7 +23,7 @@ from typing import Any from agent.component.base import ComponentBase, ComponentParamBase from jinja2 import Template as Jinja2Template -from api.utils.api_utils import timeout +from common.connection_utils import timeout class MessageParam(ComponentParamBase): diff --git a/agent/component/string_transform.py b/agent/component/string_transform.py index fe812c0a8..7802075d1 100644 --- a/agent/component/string_transform.py +++ b/agent/component/string_transform.py @@ -18,7 +18,7 @@ import re from abc import ABC from jinja2 import Template as Jinja2Template from agent.component.base import ComponentParamBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout from .message import Message diff --git a/agent/component/switch.py b/agent/component/switch.py index 8cbbde659..41c25c32f 100644 --- a/agent/component/switch.py +++ b/agent/component/switch.py @@ -19,7 +19,7 @@ from abc import ABC from typing import Any from agent.component.base import ComponentBase, ComponentParamBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class SwitchParam(ComponentParamBase): diff --git a/agent/tools/arxiv.py b/agent/tools/arxiv.py index 616afa31a..74a810c74 100644 --- a/agent/tools/arxiv.py +++ b/agent/tools/arxiv.py @@ -19,7 +19,7 @@ import time from abc import ABC import arxiv from agent.tools.base import ToolParamBase, ToolMeta, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class ArXivParam(ToolParamBase): diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index 6bd1af34e..e1fef7fbe 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -22,7 +22,7 @@ from typing import Optional from pydantic import BaseModel, Field, field_validator from agent.tools.base import ToolParamBase, ToolBase, ToolMeta from api import settings -from api.utils.api_utils import timeout +from common.connection_utils import timeout class Language(StrEnum): diff --git a/agent/tools/duckduckgo.py b/agent/tools/duckduckgo.py index 0315d6971..fcf5ee077 100644 --- a/agent/tools/duckduckgo.py +++ b/agent/tools/duckduckgo.py @@ -19,7 +19,7 @@ import time from abc import ABC from duckduckgo_search import DDGS from agent.tools.base import ToolMeta, ToolParamBase, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class DuckDuckGoParam(ToolParamBase): diff --git a/agent/tools/email.py b/agent/tools/email.py index ab6cc6ea6..42d3e2878 100644 --- a/agent/tools/email.py +++ b/agent/tools/email.py @@ -25,7 +25,7 @@ from email.header import Header from email.utils import formataddr from agent.tools.base import ToolParamBase, ToolBase, ToolMeta -from api.utils.api_utils import timeout +from common.connection_utils import timeout class EmailParam(ToolParamBase): diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py index d93745323..b5917e730 100644 --- a/agent/tools/exesql.py +++ b/agent/tools/exesql.py @@ -22,7 +22,7 @@ import pymysql import psycopg2 import pyodbc from agent.tools.base import ToolParamBase, ToolBase, ToolMeta -from api.utils.api_utils import timeout +from common.connection_utils import timeout class ExeSQLParam(ToolParamBase): diff --git a/agent/tools/github.py b/agent/tools/github.py index 27cb1e346..7b53f0b0b 100644 --- a/agent/tools/github.py +++ b/agent/tools/github.py @@ -19,7 +19,7 @@ import time from abc import ABC import requests from agent.tools.base import ToolParamBase, ToolMeta, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class GitHubParam(ToolParamBase): diff --git a/agent/tools/google.py b/agent/tools/google.py index 455038abe..3184aaaeb 100644 --- a/agent/tools/google.py +++ b/agent/tools/google.py @@ -19,7 +19,7 @@ import time from abc import ABC from serpapi import GoogleSearch from agent.tools.base import ToolParamBase, ToolMeta, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class GoogleParam(ToolParamBase): diff --git a/agent/tools/googlescholar.py b/agent/tools/googlescholar.py index bf906da4b..da7f6ef69 100644 --- a/agent/tools/googlescholar.py +++ b/agent/tools/googlescholar.py @@ -19,7 +19,7 @@ import time from abc import ABC from scholarly import scholarly from agent.tools.base import ToolMeta, ToolParamBase, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class GoogleScholarParam(ToolParamBase): diff --git a/agent/tools/pubmed.py b/agent/tools/pubmed.py index 0920b3e23..afa171768 100644 --- a/agent/tools/pubmed.py +++ b/agent/tools/pubmed.py @@ -21,7 +21,7 @@ from Bio import Entrez import re import xml.etree.ElementTree as ET from agent.tools.base import ToolParamBase, ToolMeta, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class PubMedParam(ToolParamBase): diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index bb44b124a..99ed04e2f 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -25,7 +25,7 @@ from api.db.services.dialog_service import meta_filter from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api import settings -from api.utils.api_utils import timeout +from common.connection_utils import timeout from rag.app.tag import label_question from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter diff --git a/agent/tools/searxng.py b/agent/tools/searxng.py index 32e807585..44ad18bae 100644 --- a/agent/tools/searxng.py +++ b/agent/tools/searxng.py @@ -19,7 +19,7 @@ import time from abc import ABC import requests from agent.tools.base import ToolMeta, ToolParamBase, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class SearXNGParam(ToolParamBase): diff --git a/agent/tools/tavily.py b/agent/tools/tavily.py index 80203feec..6912c3695 100644 --- a/agent/tools/tavily.py +++ b/agent/tools/tavily.py @@ -19,7 +19,7 @@ import time from abc import ABC from tavily import TavilyClient from agent.tools.base import ToolParamBase, ToolBase, ToolMeta -from api.utils.api_utils import timeout +from common.connection_utils import timeout class TavilySearchParam(ToolParamBase): diff --git a/agent/tools/wencai.py b/agent/tools/wencai.py index e2f8adefc..7ddf27ac3 100644 --- a/agent/tools/wencai.py +++ b/agent/tools/wencai.py @@ -21,7 +21,7 @@ import pandas as pd import pywencai from agent.tools.base import ToolParamBase, ToolMeta, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class WenCaiParam(ToolParamBase): diff --git a/agent/tools/wikipedia.py b/agent/tools/wikipedia.py index 83e3b13a8..8dcddc9b9 100644 --- a/agent/tools/wikipedia.py +++ b/agent/tools/wikipedia.py @@ -19,7 +19,7 @@ import time from abc import ABC import wikipedia from agent.tools.base import ToolMeta, ToolParamBase, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class WikipediaParam(ToolParamBase): diff --git a/agent/tools/yahoofinance.py b/agent/tools/yahoofinance.py index 9feea20af..3cca93f3d 100644 --- a/agent/tools/yahoofinance.py +++ b/agent/tools/yahoofinance.py @@ -20,7 +20,7 @@ from abc import ABC import pandas as pd import yfinance as yf from agent.tools.base import ToolMeta, ToolParamBase, ToolBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class YahooFinanceParam(ToolParamBase): diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 3e4acc37c..5847858ac 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -13,17 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import asyncio + import functools import json import logging import os -import queue -import threading import time from copy import deepcopy from functools import wraps -from typing import Any, Callable, Coroutine, Optional, Type, Union import requests import trio @@ -43,6 +40,7 @@ from api.db import ActiveEnum from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from common.connection_utils import timeout requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) @@ -604,82 +602,6 @@ def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, s return {}, str(e) -TimeoutException = Union[Type[BaseException], BaseException] -OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] - - -def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None): - if isinstance(seconds, str): - seconds = float(seconds) - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - result_queue = queue.Queue(maxsize=1) - - def target(): - try: - result = func(*args, **kwargs) - result_queue.put(result) - except Exception as e: - result_queue.put(e) - - thread = threading.Thread(target=target) - thread.daemon = True - thread.start() - - for a in range(attempts): - try: - if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - result = result_queue.get(timeout=seconds) - else: - result = result_queue.get() - if isinstance(result, Exception): - raise result - return result - except queue.Empty: - pass - raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.") - - @wraps(func) - async def async_wrapper(*args, **kwargs) -> Any: - if seconds is None: - return await func(*args, **kwargs) - - for a in range(attempts): - try: - if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - with trio.fail_after(seconds): - return await func(*args, **kwargs) - else: - return await func(*args, **kwargs) - except trio.TooSlowError: - if a < attempts - 1: - continue - if on_timeout is not None: - if callable(on_timeout): - result = on_timeout() - if isinstance(result, Coroutine): - return await result - return result - return on_timeout - - if exception is None: - raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.") - - if isinstance(exception, BaseException): - raise exception - - if isinstance(exception, type) and issubclass(exception, BaseException): - raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.") - - raise RuntimeError("Invalid exception type provided") - - if asyncio.iscoroutinefunction(func): - return async_wrapper - return wrapper - - return decorator - async def is_strong_enough(chat_model, embedding_model): count = settings.STRONG_TEST_COUNT diff --git a/common/connection_utils.py b/common/connection_utils.py new file mode 100644 index 000000000..bebef0f4c --- /dev/null +++ b/common/connection_utils.py @@ -0,0 +1,101 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import queue +import threading +from typing import Any, Callable, Coroutine, Optional, Type, Union +import asyncio +import trio +from functools import wraps + +TimeoutException = Union[Type[BaseException], BaseException] +OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] + + +def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, + on_timeout: Optional[OnTimeoutCallback] = None): + if isinstance(seconds, str): + seconds = float(seconds) + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + result_queue = queue.Queue(maxsize=1) + + def target(): + try: + result = func(*args, **kwargs) + result_queue.put(result) + except Exception as e: + result_queue.put(e) + + thread = threading.Thread(target=target) + thread.daemon = True + thread.start() + + for a in range(attempts): + try: + if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): + result = result_queue.get(timeout=seconds) + else: + result = result_queue.get() + if isinstance(result, Exception): + raise result + return result + except queue.Empty: + pass + raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.") + + @wraps(func) + async def async_wrapper(*args, **kwargs) -> Any: + if seconds is None: + return await func(*args, **kwargs) + + for a in range(attempts): + try: + if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): + with trio.fail_after(seconds): + return await func(*args, **kwargs) + else: + return await func(*args, **kwargs) + except trio.TooSlowError: + if a < attempts - 1: + continue + if on_timeout is not None: + if callable(on_timeout): + result = on_timeout() + if isinstance(result, Coroutine): + return await result + return result + return on_timeout + + if exception is None: + raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.") + + if isinstance(exception, BaseException): + raise exception + + if isinstance(exception, type) and issubclass(exception, BaseException): + raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.") + + raise RuntimeError("Invalid exception type provided") + + if asyncio.iscoroutinefunction(func): + return async_wrapper + return wrapper + + return decorator diff --git a/deepdoc/parser/figure_parser.py b/deepdoc/parser/figure_parser.py index 98c1d3349..b71ece121 100644 --- a/deepdoc/parser/figure_parser.py +++ b/deepdoc/parser/figure_parser.py @@ -19,7 +19,7 @@ from PIL import Image from api.db import LLMType from api.db.services.llm_service import LLMBundle -from api.utils.api_utils import timeout +from common.connection_utils import timeout from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.prompts.generator import vision_llm_figure_describe_prompt diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index fe611dedf..6c49c0a73 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -14,7 +14,7 @@ from dataclasses import dataclass import networkx as nx import pandas as pd -from api.utils.api_utils import timeout +from common.connection_utils import timeout from graphrag.general import leiden from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT from graphrag.general.extractor import Extractor diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index a41ffd6a2..9b18d694f 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -23,7 +23,7 @@ from typing import Callable import networkx as nx import trio -from api.utils.api_utils import timeout +from common.connection_utils import timeout from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT from graphrag.utils import ( GraphChange, diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 52b298e32..51f79f57f 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -23,7 +23,7 @@ import trio from api import settings from api.db.services.document_service import DocumentService from common.misc_utils import get_uuid -from api.utils.api_utils import timeout +from common.connection_utils import timeout from graphrag.entity_resolution import EntityResolution from graphrag.general.community_reports_extractor import CommunityReportsExtractor from graphrag.general.extractor import Extractor diff --git a/graphrag/utils.py b/graphrag/utils.py index 8250dea8c..b64a12265 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -25,7 +25,7 @@ from networkx.readwrite import json_graph from api import settings from common.misc_utils import get_uuid -from api.utils.api_utils import timeout +from common.connection_utils import timeout from rag.nlp import rag_tokenizer, search from rag.utils.doc_store_conn import OrderByExpr from rag.utils.redis_conn import REDIS_CONN diff --git a/rag/flow/base.py b/rag/flow/base.py index 5edc280f8..4b256e78f 100644 --- a/rag/flow/base.py +++ b/rag/flow/base.py @@ -20,7 +20,7 @@ from functools import partial from typing import Any import trio from agent.component.base import ComponentBase, ComponentParamBase -from api.utils.api_utils import timeout +from common.connection_utils import timeout class ProcessParamBase(ComponentParamBase): diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index 686a0bb7e..ee4b277fc 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -23,7 +23,7 @@ from api.db import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.user_service import TenantService -from api.utils.api_utils import timeout +from common.connection_utils import timeout from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.tokenizer.schema import TokenizerFromUpstream from rag.nlp import rag_tokenizer diff --git a/rag/raptor.py b/rag/raptor.py index a3d369189..22f8ce397 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -20,7 +20,7 @@ import numpy as np from sklearn.mixture import GaussianMixture import trio -from api.utils.api_utils import timeout +from common.connection_utils import timeout from graphrag.utils import ( get_llm_cache, get_embed_cache, diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index f42523492..5bc267ecd 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -26,7 +26,7 @@ import json_repair from api.db.services.canvas_service import UserCanvasService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService -from api.utils.api_utils import timeout +from common.connection_utils import timeout from common.base64_image import image2id from common.log_utils import init_root_logger from common.file_utils import get_project_base_directory