Move 'timeout' to common folder (#10983)

### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-11-04 11:51:12 +08:00
committed by GitHub
parent 5283a10387
commit 1e45137284
34 changed files with 135 additions and 112 deletions

View File

@ -27,7 +27,7 @@ from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase,
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool

View File

@ -25,7 +25,7 @@ from typing import Any, List, Union
import pandas as pd
import trio
from agent import settings
from api.utils.api_utils import timeout
from common.connection_utils import timeout
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"

View File

@ -21,7 +21,7 @@ from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component.llm import LLMParam, LLM
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from rag.llm.chat_model import ERROR_PREFIX

View File

@ -23,7 +23,7 @@ from abc import ABC
import requests
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from deepdoc.parser import HtmlParser

View File

@ -25,7 +25,7 @@ from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt, structured_output_prompt

View File

@ -23,7 +23,7 @@ from typing import Any
from agent.component.base import ComponentBase, ComponentParamBase
from jinja2 import Template as Jinja2Template
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class MessageParam(ComponentParamBase):

View File

@ -18,7 +18,7 @@ import re
from abc import ABC
from jinja2 import Template as Jinja2Template
from agent.component.base import ComponentParamBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from .message import Message

View File

@ -19,7 +19,7 @@ from abc import ABC
from typing import Any
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class SwitchParam(ComponentParamBase):

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
import arxiv
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class ArXivParam(ToolParamBase):

View File

@ -22,7 +22,7 @@ from typing import Optional
from pydantic import BaseModel, Field, field_validator
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api import settings
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class Language(StrEnum):

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
from duckduckgo_search import DDGS
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class DuckDuckGoParam(ToolParamBase):

View File

@ -25,7 +25,7 @@ from email.header import Header
from email.utils import formataddr
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class EmailParam(ToolParamBase):

View File

@ -22,7 +22,7 @@ import pymysql
import psycopg2
import pyodbc
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class ExeSQLParam(ToolParamBase):

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
import requests
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class GitHubParam(ToolParamBase):

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
from serpapi import GoogleSearch
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class GoogleParam(ToolParamBase):

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
from scholarly import scholarly
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class GoogleScholarParam(ToolParamBase):

View File

@ -21,7 +21,7 @@ from Bio import Entrez
import re
import xml.etree.ElementTree as ET
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class PubMedParam(ToolParamBase):

View File

@ -25,7 +25,7 @@ from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
import requests
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class SearXNGParam(ToolParamBase):

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
from tavily import TavilyClient
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class TavilySearchParam(ToolParamBase):

View File

@ -21,7 +21,7 @@ import pandas as pd
import pywencai
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class WenCaiParam(ToolParamBase):

View File

@ -19,7 +19,7 @@ import time
from abc import ABC
import wikipedia
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class WikipediaParam(ToolParamBase):

View File

@ -20,7 +20,7 @@ from abc import ABC
import pandas as pd
import yfinance as yf
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class YahooFinanceParam(ToolParamBase):

View File

@ -13,17 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import os
import queue
import threading
import time
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Coroutine, Optional, Type, Union
import requests
import trio
@ -43,6 +40,7 @@ from api.db import ActiveEnum
from api.db.db_models import APIToken
from api.utils.json_encode import CustomJSONEncoder
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.connection_utils import timeout
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@ -604,82 +602,6 @@ def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, s
return {}, str(e)
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
if isinstance(seconds, str):
seconds = float(seconds)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
result_queue = queue.Queue(maxsize=1)
def target():
try:
result = func(*args, **kwargs)
result_queue.put(result)
except Exception as e:
result_queue.put(e)
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
result = result_queue.get(timeout=seconds)
else:
result = result_queue.get()
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None:
return await func(*args, **kwargs)
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
with trio.fail_after(seconds):
return await func(*args, **kwargs)
else:
return await func(*args, **kwargs)
except trio.TooSlowError:
if a < attempts - 1:
continue
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout
if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
if isinstance(exception, BaseException):
raise exception
if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
raise RuntimeError("Invalid exception type provided")
if asyncio.iscoroutinefunction(func):
return async_wrapper
return wrapper
return decorator
async def is_strong_enough(chat_model, embedding_model):
count = settings.STRONG_TEST_COUNT

101
common/connection_utils.py Normal file
View File

@ -0,0 +1,101 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import queue
import threading
from typing import Any, Callable, Coroutine, Optional, Type, Union
import asyncio
import trio
from functools import wraps
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None,
on_timeout: Optional[OnTimeoutCallback] = None):
if isinstance(seconds, str):
seconds = float(seconds)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
result_queue = queue.Queue(maxsize=1)
def target():
try:
result = func(*args, **kwargs)
result_queue.put(result)
except Exception as e:
result_queue.put(e)
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
result = result_queue.get(timeout=seconds)
else:
result = result_queue.get()
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None:
return await func(*args, **kwargs)
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
with trio.fail_after(seconds):
return await func(*args, **kwargs)
else:
return await func(*args, **kwargs)
except trio.TooSlowError:
if a < attempts - 1:
continue
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout
if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
if isinstance(exception, BaseException):
raise exception
if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
raise RuntimeError("Invalid exception type provided")
if asyncio.iscoroutinefunction(func):
return async_wrapper
return wrapper
return decorator

View File

@ -19,7 +19,7 @@ from PIL import Image
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.prompts.generator import vision_llm_figure_describe_prompt

View File

@ -14,7 +14,7 @@ from dataclasses import dataclass
import networkx as nx
import pandas as pd
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor

View File

@ -23,7 +23,7 @@ from typing import Callable
import networkx as nx
import trio
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import (
GraphChange,

View File

@ -23,7 +23,7 @@ import trio
from api import settings
from api.db.services.document_service import DocumentService
from common.misc_utils import get_uuid
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from graphrag.entity_resolution import EntityResolution
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.general.extractor import Extractor

View File

@ -25,7 +25,7 @@ from networkx.readwrite import json_graph
from api import settings
from common.misc_utils import get_uuid
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN

View File

@ -20,7 +20,7 @@ from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
from common.connection_utils import timeout
class ProcessParamBase(ComponentParamBase):

View File

@ -23,7 +23,7 @@ from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.tokenizer.schema import TokenizerFromUpstream
from rag.nlp import rag_tokenizer

View File

@ -20,7 +20,7 @@ import numpy as np
from sklearn.mixture import GaussianMixture
import trio
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from graphrag.utils import (
get_llm_cache,
get_embed_cache,

View File

@ -26,7 +26,7 @@ import json_repair
from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.utils.api_utils import timeout
from common.connection_utils import timeout
from common.base64_image import image2id
from common.log_utils import init_root_logger
from common.file_utils import get_project_base_directory