From fbd115773b6f879a94bd4f59a8e75ee8d2ef8445 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Wed, 16 Jul 2025 18:06:03 +0800
Subject: [PATCH] Perf: set timeout of some steps in KG. (#8873)

### What problem does this PR solve?

Knowledge-graph (KG) construction steps could run far longer than useful and keep a task executor occupied. This PR tightens their time budgets: `generate_subgraph` goes from a 2-hour timeout to 1 hour, `resolve_entities` from 1 hour to 30 minutes, `do_handle_task` from 1.5 hours to 1 hour, and the chunk-writing helpers `graph_node_to_chunk` / `graph_edge_to_chunk` gain timeouts of their own. Alongside that, `llm_id2llm_type` moves from `rag/prompts.py` into `TenantLLMService`, `api/settings.py` defers its `graphrag` import until `init_settings()`, and progress messages now report the queue depth when a graphrag or raptor task is created.

### Type of change

- [x] Performance Improvement
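The `timeout` decorator applied below is defined in `api/utils/api_utils.py` and is not itself changed by this diff, so its signature has to be read from the call sites: the first argument is a budget in seconds, and the newly passed second argument looks like an attempt count. A minimal sketch of those assumed semantics on top of `trio` (which the executors already use) — names and retry behavior here are illustrative, not the project's implementation:

```python
import functools

import trio


def timeout(seconds: float, attempts: int = 1):
    """Assumed semantics: cancel the wrapped coroutine after `seconds`,
    retrying up to `attempts` times before giving up."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(max(1, attempts)):
                try:
                    with trio.fail_after(seconds):
                        return await func(*args, **kwargs)
                except trio.TooSlowError as exc:
                    last_exc = exc
            raise last_exc
        return wrapper
    return decorator
```

Under this reading, `@timeout(60*60, 1)` gives `generate_subgraph` a single one-hour attempt instead of the previous two-hour budget, and `@timeout(1, 3)` retries the small chunk-writing steps instead of letting one hang.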
---
 api/db/services/dialog_service.py   |  8 ++++----
 api/db/services/document_service.py |  2 ++
 api/db/services/llm_service.py      |  9 +++++++++
 api/settings.py                     |  2 +-
 api/utils/api_utils.py              |  4 +---
 graphrag/general/index.py           |  6 +++---
 graphrag/utils.py                   |  5 +++--
 rag/prompts.py                      | 19 ++++---------------
 rag/svr/task_executor.py            |  2 +-
 9 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 211178a51..37dc05d92 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -36,7 +36,7 @@ from api.utils import current_timestamp, datetime_format
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
 
@@ -97,7 +97,7 @@ class DialogService(CommonService):
 
 
 def chat_solo(dialog, messages, stream=True):
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -139,7 +139,7 @@ def get_models(dialog):
     if not embd_mdl:
         raise LookupError("Embedding model(%s) not found" % embedding_list[0])
 
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -198,7 +198,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 
     chat_start_ts = timer()
 
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index a9dfcb438..feedaed31 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -583,6 +583,8 @@ class DocumentService(CommonService):
                 info["progress"] = prg
             if msg:
                 info["progress_msg"] = msg
+                if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
+                    info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
             else:
                 info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
             cls.update_by_id(d["id"], info)
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 462eb9d4a..92655abfd 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -214,6 +214,15 @@ class TenantLLMService(CommonService):
         objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)
 
+    @staticmethod
+    def llm_id2llm_type(llm_id: str) -> str | None:
+        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
+        llm_factories = settings.FACTORY_LLM_INFOS
+        for llm_factory in llm_factories:
+            for llm in llm_factory["llm"]:
+                if llm_id == llm["llm_name"]:
+                    return llm["model_type"].strip(",")[-1]
+
 
 class LLMBundle:
     def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
diff --git a/api/settings.py b/api/settings.py
index f5fd80fb1..4d0e18b68 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -26,7 +26,6 @@ import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
-from graphrag import search as kg_search
 from rag.nlp import search
 
 LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
@@ -169,6 +168,7 @@ def init_settings():
         raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
 
     retrievaler = search.Dealer(docStoreConn)
+    from graphrag import search as kg_search
     kg_retrievaler = kg_search.KGSearch(docStoreConn)
 
     if int(os.environ.get("SANDBOX_ENABLED", "0")):
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index f574a8f00..078aa7bf2 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -31,8 +31,6 @@ from urllib.parse import quote, urlencode
 from uuid import uuid1
 
 import trio
-
-from api.db.db_models import MCPServer
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
 
 
@@ -570,7 +568,7 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
     return transformed_data
 
 
-def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
+def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
     results = {}
     tool_call_sessions = []
     try:
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index fe54747f4..8ac65bfbd 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -124,7 +124,7 @@ async def run_graphrag(
     return
 
 
-@timeout(60*60*2)
+@timeout(60*60, 1)
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,
@@ -229,7 +229,7 @@ async def merge_subgraph(
     return new_graph
 
 
-@timeout(60*60)
+@timeout(60*30, 1)
 async def resolve_entities(
     graph,
     subgraph_nodes: set[str],
@@ -255,7 +255,7 @@ async def resolve_entities(
     callback(msg=f"Graph resolution done in {now - start:.2f}s.")
 
 
-@timeout(60*30)
+@timeout(60*30, 1)
 async def extract_community(
     graph,
     tenant_id: str,
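`llm_id2llm_type` now lives on `TenantLLMService` (added above in `api/db/services/llm_service.py`), so `rag/prompts.py` no longer needs `api.settings` at import time. Every caller in this diff repeats the same dispatch; condensed into one helper for reference (`chat_bundle` is an illustrative name, not part of the PR):

```python
from api.db import LLMType
from api.db.services.llm_service import LLMBundle, TenantLLMService


def chat_bundle(tenant_id: str, llm_id: str) -> LLMBundle:
    # Vision-capable model ids map to an IMAGE2TEXT bundle; everything else chats.
    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        return LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    return LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```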
diff --git a/graphrag/utils.py b/graphrag/utils.py
index d414f2c41..ccaf506f2 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -17,13 +17,12 @@ from typing import Any, Callable
 import os
 import trio
 from typing import Set, Tuple
-
 import networkx as nx
 import numpy as np
 import xxhash
 from networkx.readwrite import json_graph
 import dataclasses
-
+from api.utils.api_utils import timeout
 from api import settings
 from api.utils import get_uuid
 from rag.nlp import search, rag_tokenizer
@@ -305,6 +304,7 @@ def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
 
 
+@timeout(1, 3)
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
@@ -357,6 +357,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     return res
 
 
+@timeout(1, 3)
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
diff --git a/rag/prompts.py b/rag/prompts.py
index ac04846da..ccaa198b6 100644
--- a/rag/prompts.py
+++ b/rag/prompts.py
@@ -22,7 +22,6 @@ from collections import defaultdict
 import jinja2
 import json_repair
 
-from api import settings
 from rag.prompt_template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string
@@ -51,18 +50,6 @@ def chunks_format(reference):
     ]
 
 
-def llm_id2llm_type(llm_id):
-    from api.db.services.llm_service import TenantLLMService
-
-    llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
-
-    llm_factories = settings.FACTORY_LLM_INFOS
-    for llm_factory in llm_factories:
-        for llm in llm_factory["llm"]:
-            if llm_id == llm["llm_name"]:
-                return llm["model_type"].strip(",")[-1]
-
-
 def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg
@@ -188,8 +175,9 @@ def question_proposal(chat_mdl, content, topn=3):
 def full_question(tenant_id, llm_id, messages, language=None):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService
 
-    if llm_id2llm_type(llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
@@ -220,8 +208,9 @@ def full_question(tenant_id, llm_id, messages, language=None):
 def cross_languages(tenant_id, llm_id, query, languages=[]):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService
 
-    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 952737d83..c39bb982e 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -506,7 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count
 
 
-@timeout(60*60*1.5)
+@timeout(60*60, 1)
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
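For reference, the `DocumentService.update_progress` change near the top condenses to the following predicate (the standalone function and its name are illustrative; the real code mutates `info["progress_msg"]` in place inside the service):

```python
def progress_message(msg: str, queue_len: int) -> str:
    # Only the graphrag/raptor task-creation milestones append the queue depth.
    if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
        return msg + "\n%d tasks are ahead in the queue..." % queue_len
    return msg


# e.g. progress_message("created task graphrag", 4)
# -> 'created task graphrag\n4 tasks are ahead in the queue...'
```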