Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00

Perf: Enhance timeout handling. (#8826)

### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
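
In brief, as reflected in the diff below (this view does not show file paths; module names are taken from the import statements):

- `get_mcp_tools` moves out of the standalone `api.utils.mcp_server` module, which is deleted, into `api.utils.api_utils`, and the MCP server app gains the standard Apache-2.0 license header.
- `api.utils.api_utils` gains a reusable `timeout` decorator: synchronous functions run in a daemon worker thread and raise `TimeoutError` when the budget expires, while async functions are wrapped in `trio.fail_after`, with optional `exception` and `on_timeout` overrides.
- Explicit `@timeout(...)` budgets are applied across the GraphRAG pipeline, RAPTOR, and the task executor, and the inner LLM-call deadline in community report extraction is tightened from 120 to 80 seconds.
- `RedisDB` now reopens its connection after queue exceptions and retries `queue_consumer`, `requeue_msg`, and `queue_info` up to three times.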

@@ -1,3 +1,18 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 from flask import Response, request
 from flask_login import current_user, login_required
 
@@ -6,9 +21,10 @@ from api.db.db_models import MCPServer
 from api.db.services.mcp_server_service import MCPServerService
 from api.db.services.user_service import TenantService
 from api.settings import RetCode
 
 from api.utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
-from api.utils.mcp_server import get_mcp_tools
+from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
+    get_mcp_tools
 from api.utils.web_utils import get_float, safe_json_parse
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

@@ -13,19 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import functools
 import json
 import logging
+import queue
 import random
+import threading
 import time
 from base64 import b64encode
 from copy import deepcopy
 from functools import wraps
 from hmac import HMAC
 from io import BytesIO
+from typing import Any, Optional, Union, Callable, Coroutine, Type
 from urllib.parse import quote, urlencode
 from uuid import uuid1
 
+import trio
+
+from api.db.db_models import MCPServer
+from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
+
+
 import requests
 from flask import (
     Response,

@@ -558,3 +568,101 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
         transformed_data[mapped_key] = value
 
     return transformed_data
+
+
+def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
+    results = {}
+    tool_call_sessions = []
+    try:
+        for mcp_server in mcp_servers:
+            server_key = mcp_server.id
+
+            cached_tools = mcp_server.variables.get("tools", {})
+
+            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
+            tool_call_sessions.append(tool_call_session)
+
+            try:
+                tools = tool_call_session.get_tools(timeout)
+            except Exception:
+                tools = []
+
+            results[server_key] = []
+            for tool in tools:
+                tool_dict = tool.model_dump()
+                cached_tool = cached_tools.get(tool_dict["name"], {})
+
+                tool_dict["enabled"] = cached_tool.get("enabled", True)
+                results[server_key].append(tool_dict)
+
+        # PERF: blocking call to close sessions — consider moving to background thread or task queue
+        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
+        return results, ""
+    except Exception as e:
+        return {}, str(e)
+
+
+TimeoutException = Union[Type[BaseException], BaseException]
+OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
+
+
+def timeout(
+    seconds: float | int = None,
+    *,
+    exception: Optional[TimeoutException] = None,
+    on_timeout: Optional[OnTimeoutCallback] = None
+):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            result_queue = queue.Queue(maxsize=1)
+
+            def target():
+                try:
+                    result = func(*args, **kwargs)
+                    result_queue.put(result)
+                except Exception as e:
+                    result_queue.put(e)
+
+            thread = threading.Thread(target=target)
+            thread.daemon = True
+            thread.start()
+
+            try:
+                result = result_queue.get(timeout=seconds)
+                if isinstance(result, Exception):
+                    raise result
+                return result
+            except queue.Empty:
+                raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs) -> Any:
+            if seconds is None:
+                return await func(*args, **kwargs)
+
+            try:
+                with trio.fail_after(seconds):
+                    return await func(*args, **kwargs)
+            except trio.TooSlowError:
+                if on_timeout is not None:
+                    if callable(on_timeout):
+                        result = on_timeout()
+                        if isinstance(result, Coroutine):
+                            return await result
+                        return result
+                    return on_timeout
+
+                if exception is None:
+                    raise TimeoutError(f"Operation timed out after {seconds} seconds")
+
+                if isinstance(exception, BaseException):
+                    raise exception
+
+                if isinstance(exception, type) and issubclass(exception, BaseException):
+                    raise exception(f"Operation timed out after {seconds} seconds")
+
+                raise RuntimeError("Invalid exception type provided")
+
+        if asyncio.iscoroutinefunction(func):
+            return async_wrapper
+        return wrapper
+    return decorator
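
For orientation, a minimal usage sketch of the decorator added above (illustrative only, not part of the commit; the toy functions are hypothetical):

```python
import time

import trio

from api.utils.api_utils import timeout


@timeout(1)
def slow_sync_call():
    time.sleep(5)  # exceeds the 1s budget; keeps running in the abandoned daemon thread
    return "sync done"


@timeout(1, on_timeout=lambda: "fallback")
async def slow_async_call():
    await trio.sleep(5)  # cancelled by trio.fail_after(1)
    return "async done"


try:
    slow_sync_call()  # raises TimeoutError after about one second
except TimeoutError as e:
    print(e)

print(trio.run(slow_async_call))  # prints "fallback" instead of raising
```

Note that the sync path cannot actually kill the worker: it abandons the daemon thread and returns control to the caller, so a decorated synchronous function should not hold locks or other non-reentrant resources past its budget. The async path cancels cleanly via trio.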

@@ -1,34 +0,0 @@
-from api.db.db_models import MCPServer
-from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
-
-
-def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
-    results = {}
-    tool_call_sessions = []
-    try:
-        for mcp_server in mcp_servers:
-            server_key = mcp_server.id
-
-            cached_tools = mcp_server.variables.get("tools", {})
-
-            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
-            tool_call_sessions.append(tool_call_session)
-
-            try:
-                tools = tool_call_session.get_tools(timeout)
-            except Exception:
-                tools = []
-
-            results[server_key] = []
-            for tool in tools:
-                tool_dict = tool.model_dump()
-                cached_tool = cached_tools.get(tool_dict["name"], {})
-
-                tool_dict["enabled"] = cached_tool.get("enabled", True)
-                results[server_key].append(tool_dict)
-
-        # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
-        return results, ""
-    except Exception as e:
-        return {}, str(e)
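
The hunk above deletes the old standalone `api.utils.mcp_server` module; its entire body, `get_mcp_tools`, reappears verbatim in `api.utils.api_utils`, so this is a pure move rather than a rewrite.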

@@ -12,6 +12,8 @@ from typing import Callable
 from dataclasses import dataclass
 import networkx as nx
 import pandas as pd
+
+from api.utils.api_utils import timeout
 from graphrag.general import leiden
 from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
 from graphrag.general.extractor import Extractor

@@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor):
         res_str = []
         res_dict = []
         over, token_count = 0, 0
+        @timeout(120)
         async def extract_community_report(community):
             nonlocal res_str, res_dict, over, token_count
             cm_id, cm = community

@@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor):
             gen_conf = {"temperature": 0.3}
             async with chat_limiter:
                 try:
-                    with trio.move_on_after(120) as cancel_scope:
+                    with trio.move_on_after(80) as cancel_scope:
                         response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
                     if cancel_scope.cancelled_caught:
                         logging.warning("extract_community_report._chat timeout, skipping...")
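
The inner per-chat deadline drops from 120 s to 80 s so that it fires comfortably before the new outer `@timeout(120)` on `extract_community_report`: a stalled LLM call is skipped with a warning instead of cancelling the whole report. A toy sketch of this nested-deadline pattern (illustrative only, with shortened durations):

```python
import trio


async def main():
    try:
        with trio.fail_after(3):  # outer hard budget, cf. @timeout(120)
            with trio.move_on_after(1):  # inner soft budget, cf. move_on_after(80)
                await trio.sleep(10)  # the slow "LLM call"; cancelled after 1s
            print("inner scope gave up; outer scope continues")
    except trio.TooSlowError:
        print("outer deadline hit")


trio.run(main)  # prints "inner scope gave up; outer scope continues"
```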

@@ -21,6 +21,7 @@ from typing import Callable
 import trio
 import networkx as nx
 
+from api.utils.api_utils import timeout
 from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
 from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
     handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange

@@ -46,6 +47,7 @@ class Extractor:
         self._language = language
         self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
 
+    @timeout(60)
     def _chat(self, system, history, gen_conf):
         hist = deepcopy(history)
         conf = deepcopy(gen_conf)
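
`_chat` is synchronous, so `@timeout(60)` takes the decorator's thread-queue path: the LLM request runs in a daemon thread and the caller raises `TimeoutError` after 60 seconds. Stacked with the hunks above, a community report now has roughly three layers of protection: 60 s per raw chat call, 80 s for the `move_on_after` scope around it, and 120 s for the whole `extract_community_report` coroutine.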

@@ -20,6 +20,7 @@ import trio
 
 from api import settings
 from api.utils import get_uuid
+from api.utils.api_utils import timeout
 from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
 from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
 from graphrag.general.community_reports_extractor import CommunityReportsExtractor

@@ -123,6 +124,7 @@ async def run_graphrag(
         return
 
 
+@timeout(60*60*2)
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,

@@ -194,6 +196,8 @@ async def generate_subgraph(
     callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
     return subgraph
 
+
+@timeout(60*3)
 async def merge_subgraph(
     tenant_id: str,
     kb_id: str,

@@ -225,6 +229,7 @@ async def merge_subgraph(
     return new_graph
 
 
+@timeout(60*60)
 async def resolve_entities(
     graph,
     subgraph_nodes: set[str],

@@ -250,6 +255,7 @@ async def resolve_entities(
     callback(msg=f"Graph resolution done in {now - start:.2f}s.")
 
 
+@timeout(60*30)
 async def extract_community(
     graph,
     tenant_id: str,
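
Taken together, each GraphRAG stage now has an explicit ceiling: two hours to generate a per-document subgraph, three minutes to merge it, one hour for entity resolution, and thirty minutes for community extraction. These are async functions, so the decorator's `trio.fail_after` path applies and a stuck stage is cancelled cleanly rather than hanging the pipeline.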

@@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags):
     k = hasher.hexdigest()
     REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
 
+
 def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
     """
     Ensure all nodes and edges in the graph have some essential attribute.

@@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
     if purged_edges and callback:
         callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")
 
+
 def get_from_to(node1, node2):
     if node1 < node2:
         return (node1, node2)
     else:
         return (node2, node1)
 
+
 def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
     """Merge graph g2 into g1 in place."""
     for node_name, attr in g2.nodes(data=True):

@@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
     g1.graph["source_id"] += g2.graph.get("source_id", [])
     return g1
 
+
 def compute_args_hash(*args):
     return md5(str(args).encode()).hexdigest()
 

@@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     chunk["q_%d_vec" % len(ebd)] = ebd
     chunks.append(chunk)
 
+
 async def does_graph_contains(tenant_id, kb_id, doc_id):
     # Get doc_ids of graph
     fields = ["source_id"]

@@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
     graph_doc_ids = set(fields2[chunk_id]["source_id"])
     return doc_id in graph_doc_ids
 
+
 async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
     conds = {
         "fields": ["source_id"],

@@ -20,6 +20,7 @@ import numpy as np
 from sklearn.mixture import GaussianMixture
 import trio
 
+from api.utils.api_utils import timeout
 from graphrag.utils import (
     get_llm_cache,
     get_embed_cache,

@@ -54,6 +55,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
         return response
 
+    @timeout(2)
     async def _embedding_encode(self, txt):
         response = get_embed_cache(self._embd_model.llm_name, txt)
         if response is not None:

@@ -83,6 +85,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)
 
+        @timeout(60)
         async def summarize(ck_idx: list[int]):
             nonlocal chunks
             texts = [chunks[i][0] for i in ck_idx]
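
RAPTOR gets budgets of its own: a tight two seconds per `_embedding_encode` call and sixty seconds per `summarize` call. Both are async, so an overrun is cancelled via `trio.fail_after` rather than left stranded.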

@@ -21,6 +21,7 @@ import sys
 import threading
 import time
 
+from api.utils.api_utils import timeout
 from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache

@@ -275,6 +276,7 @@ async def build_chunks(task, progress_callback):
         doc[PAGERANK_FLD] = int(task["pagerank"])
     st = timer()
 
+    @timeout(60)
     async def upload_to_minio(document, chunk):
         try:
             d = copy.deepcopy(document)

@@ -415,6 +417,7 @@ def init_kb(row, vector_size: int):
     return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
 
 
+@timeout(60*20)
 async def embedding(docs, mdl, parser_config=None, callback=None):
     if parser_config is None:
         parser_config = {}

@@ -461,6 +464,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     return tk_count, vector_size
 
 
+@timeout(3600)
 async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size

@@ -502,6 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count
 
 
+@timeout(60*60*1.5)
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
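
The task executor gets the same treatment: 60 s per MinIO upload, 20 minutes for an embedding pass, an hour for RAPTOR, and 1.5 hours for `do_handle_task` as a whole, so a single wedged document can no longer pin a worker indefinitely.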

@@ -220,10 +220,12 @@ class RedisDB:
             logging.exception(
                 "RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
             )
+            self.__open__()
         return False
 
     def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
         """https://redis.io/docs/latest/commands/xreadgroup/"""
+        for _ in range(3):
             try:
                 group_info = self.REDIS.xinfo_groups(queue_name)
                 if not any(gi["name"] == group_name for gi in group_info):

@@ -254,6 +256,7 @@ class RedisDB:
                     + " got exception: "
                     + str(e)
                 )
+                self.__open__()
         return None
 
     def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):

@@ -294,6 +297,7 @@ class RedisDB:
         return []
 
     def requeue_msg(self, queue: str, group_name: str, msg_id: str):
+        for _ in range(3):
             try:
                 messages = self.REDIS.xrange(queue, msg_id, msg_id)
                 if messages:

@@ -303,8 +307,10 @@ class RedisDB:
                 logging.warning(
                     "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
                 )
+                self.__open__()
 
     def queue_info(self, queue, group_name) -> dict | None:
+        for _ in range(3):
             try:
                 groups = self.REDIS.xinfo_groups(queue)
                 for group in groups:

@@ -314,6 +320,7 @@ class RedisDB:
                 logging.warning(
                     "RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
                 )
+                self.__open__()
         return None
 
     def delete_if_equal(self, key: str, expected_value: str) -> bool:
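
All of the `RedisDB` changes follow one pattern: on an exception, reopen the connection via `self.__open__()`, and wrap read paths in a three-attempt loop. A generic sketch of that retry-with-reconnect idea (illustrative only; `connect` and `fetch` are hypothetical stand-ins, not ragflow APIs):

```python
import logging
from typing import Any, Callable, Optional


def with_reconnect(connect: Callable[[], Any], fetch: Callable[[Any], Any], attempts: int = 3) -> Optional[Any]:
    conn = connect()
    for _ in range(attempts):
        try:
            return fetch(conn)
        except Exception as e:
            # On failure, rebuild the connection before retrying,
            # mirroring the self.__open__() calls added above.
            logging.warning("fetch failed: %s; reopening connection", e)
            conn = connect()
    return None  # give up after the last attempt, as queue_info does
```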