Perf: set timeout of some steps in KG. (#8873)

### What problem does this PR solve?
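
Tightens timeouts on several knowledge-graph (GraphRAG) steps: `generate_subgraph` goes from a 2-hour to a 1-hour budget, `resolve_entities` from 1 hour to 30 minutes, `do_handle_task` from 1.5 hours to 1 hour, and the per-node/edge chunk builders gain `@timeout(1, 3)`. Each `@timeout` call now also passes a second argument (an attempt count, judging by the call sites). Supporting changes: `llm_id2llm_type` moves from `rag.prompts` into `TenantLLMService`, `api.settings` defers its `graphrag.search` import into `init_settings()`, and progress messages for queued graphrag/raptor tasks now append the queue backlog instead of overwriting the message.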

### Type of change


- [x] Performance Improvement
Kevin Hu, 2025-07-16 18:06:03 +08:00 (committed by GitHub)
parent b3018a455f, commit fbd115773b
9 changed files with 28 additions and 29 deletions

View File

@@ -36,7 +36,7 @@ from api.utils import current_timestamp, datetime_format
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
@@ -97,7 +97,7 @@ class DialogService(CommonService):


 def chat_solo(dialog, messages, stream=True):
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -139,7 +139,7 @@ def get_models(dialog):
     if not embd_mdl:
         raise LookupError("Embedding model(%s) not found" % embedding_list[0])

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -198,7 +198,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     chat_start_ts = timer()

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

View File

@@ -583,6 +583,8 @@ class DocumentService(CommonService):
             info["progress"] = prg
             if msg:
                 info["progress_msg"] = msg
+                if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
+                    info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
             else:
                 info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
             cls.update_by_id(d["id"], info)
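
A minimal illustration of the new branch, with hypothetical values standing in for the real message and for `get_queue_length(priority)`: when the latest progress message marks a freshly queued graphrag or raptor task, the queue depth is appended rather than overwriting the message.

```python
# Hypothetical stand-ins: msg comes from the task, queue_length from get_queue_length(priority).
msg = "created task graphrag"
queue_length = 5

info = {"progress_msg": msg}
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
    # Append the backlog so the "created task ..." note is preserved.
    info["progress_msg"] += "\n%d tasks are ahead in the queue..." % queue_length

print(info["progress_msg"])
# created task graphrag
# 5 tasks are ahead in the queue...
```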

View File

@@ -214,6 +214,15 @@ class TenantLLMService(CommonService):
         objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)

+    @staticmethod
+    def llm_id2llm_type(llm_id: str) -> str | None:
+        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
+        llm_factories = settings.FACTORY_LLM_INFOS
+        for llm_factory in llm_factories:
+            for llm in llm_factory["llm"]:
+                if llm_id == llm["llm_name"]:
+                    return llm["model_type"].strip(",")[-1]
+

 class LLMBundle:
     def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
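
This helper previously lived in `rag.prompts` and imported `TenantLLMService` inside its own body to dodge a circular import; hosting it on the service removes that workaround. A usage sketch mirroring the updated call sites (the model id is made up, and `tenant_id` is assumed to be in scope):

```python
# Hypothetical "name@factory" model id; real ids come from the tenant's configured models.
llm_id = "qwen-vl-max@Tongyi-Qianwen"

# Resolve the model's declared type, then build the matching bundle.
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
    chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```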

View File

@@ -26,7 +26,6 @@ import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
-from graphrag import search as kg_search
 from rag.nlp import search

 LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
@@ -169,6 +168,7 @@ def init_settings():
         raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

     retrievaler = search.Dealer(docStoreConn)
+    from graphrag import search as kg_search
     kg_retrievaler = kg_search.KGSearch(docStoreConn)

     if int(os.environ.get("SANDBOX_ENABLED", "0")):
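
Moving the `graphrag` import from module scope into `init_settings()` defers it until the function runs, the usual way to break an import cycle: the importing module finishes loading before the dependency is pulled in. A generic two-module sketch of the pattern (module and function names are stand-ins, not repo code):

```python
# a.py -- importing b at module scope could fail here if b also imports a;
# deferring the import to call time sidesteps the cycle.
VALUE = 42

def init():
    from b import make_thing  # resolved only when init() runs, after a.py is fully loaded
    return make_thing()

# b.py
import a  # safe: a.py no longer needs b at import time

def make_thing():
    return a.VALUE
```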

View File

@@ -31,8 +31,6 @@ from urllib.parse import quote, urlencode
 from uuid import uuid1

 import trio

-from api.db.db_models import MCPServer
-
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@@ -570,7 +568,7 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
     return transformed_data


-def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
+def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
     results = {}
     tool_call_sessions = []
     try:
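
Dropping the `api.db.db_models` import loosens the annotation from `list[MCPServer]` to a bare `list`, presumably to cut a runtime dependency. If the precise type is still wanted without the runtime import, a `TYPE_CHECKING` guard is the usual alternative; a sketch assuming the same module layout:

```python
from __future__ import annotations  # annotations are evaluated lazily, as strings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from api.db.db_models import MCPServer

def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    ...
```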

View File

@@ -124,7 +124,7 @@ async def run_graphrag(
         return


-@timeout(60*60*2)
+@timeout(60*60, 1)
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,
@@ -229,7 +229,7 @@ async def merge_subgraph(
     return new_graph


-@timeout(60*60)
+@timeout(60*30, 1)
 async def resolve_entities(
     graph,
     subgraph_nodes: set[str],
@@ -255,7 +255,7 @@ async def resolve_entities(
         callback(msg=f"Graph resolution done in {now - start:.2f}s.")


-@timeout(60*30)
+@timeout(60*30, 1)
 async def extract_community(
     graph,
     tenant_id: str,
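
The second argument now passed to `@timeout` is not defined in this diff (the decorator lives in `api.utils.api_utils`); judging by the call sites it reads like an attempt count, but that is an assumption. A minimal trio-based sketch with that `(seconds, attempts)` shape, not the repo's actual implementation:

```python
import functools

import trio

def timeout(seconds, attempts=1):
    """Cancel the wrapped async function after `seconds`, retrying up to `attempts` times."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(max(1, attempts)):
                try:
                    with trio.fail_after(seconds):  # raises trio.TooSlowError on expiry
                        return await func(*args, **kwargs)
                except trio.TooSlowError as exc:
                    last_exc = exc
            raise TimeoutError(f"{func.__name__}: {attempts} attempt(s) of {seconds}s exhausted") from last_exc
        return wrapper
    return decorator
```

Under that reading, `generate_subgraph` now gets a single one-hour attempt where it previously had a two-hour budget.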

View File

@@ -17,13 +17,12 @@ from typing import Any, Callable
 import os
 import trio
-from typing import Set, Tuple

 import networkx as nx
 import numpy as np
 import xxhash
 from networkx.readwrite import json_graph
-import dataclasses

+from api.utils.api_utils import timeout
 from api import settings
 from api.utils import get_uuid
 from rag.nlp import search, rag_tokenizer
@@ -305,6 +304,7 @@ def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


+@timeout(1, 3)
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
@@ -357,6 +357,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     return res


+@timeout(1, 3)
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),

View File

@@ -22,7 +22,6 @@ from collections import defaultdict
 import jinja2
 import json_repair

-from api import settings
 from rag.prompt_template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string
@@ -51,18 +50,6 @@ def chunks_format(reference):
     ]


-def llm_id2llm_type(llm_id):
-    from api.db.services.llm_service import TenantLLMService
-
-    llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
-
-    llm_factories = settings.FACTORY_LLM_INFOS
-    for llm_factory in llm_factories:
-        for llm in llm_factory["llm"]:
-            if llm_id == llm["llm_name"]:
-                return llm["model_type"].strip(",")[-1]
-
-
 def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg
@@ -188,8 +175,9 @@ def question_proposal(chat_mdl, content, topn=3):
 def full_question(tenant_id, llm_id, messages, language=None):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id2llm_type(llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
@@ -220,8 +208,9 @@ def cross_languages(tenant_id, llm_id, query, languages=[]):
 def cross_languages(tenant_id, llm_id, query, languages=[]):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
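
Note that `full_question` and `cross_languages` still import their `api.db` dependencies inside the function body; adding `TenantLLMService` there keeps `rag.prompts` importable without dragging in the database layer at module load, which is presumably why the helper could not simply stay in this module.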

View File

@@ -506,7 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count


-@timeout(60*60*1.5)
+@timeout(60*60, 1)
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
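
The end-to-end task budget drops from `60*60*1.5` (5400 s) to `60*60` (3600 s); with the assumed attempt count of 1, a hung task is now cancelled after one hour rather than ninety minutes.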