Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Perf: set timeout of some steps in KG. (#8873)
### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
@@ -36,7 +36,7 @@ from api.utils import current_timestamp, datetime_format
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily

@@ -97,7 +97,7 @@ class DialogService(CommonService):


 def chat_solo(dialog, messages, stream=True):
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -139,7 +139,7 @@ def get_models(dialog):
     if not embd_mdl:
         raise LookupError("Embedding model(%s) not found" % embedding_list[0])

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -198,7 +198,7 @@ def chat(dialog, messages, stream=True, **kwargs):

     chat_start_ts = timer()

-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -583,6 +583,8 @@ class DocumentService(CommonService):
                 info["progress"] = prg
             if msg:
                 info["progress_msg"] = msg
+                if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
+                    info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
             else:
                 info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
             cls.update_by_id(d["id"], info)
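For illustration, a runnable sketch of the message this new branch produces. The queue length here is a hypothetical value, since `get_queue_length(priority)` is only visible in the diff as a call site:

```python
# Hypothetical values, for illustration only: get_queue_length(priority)
# is assumed to return the number of tasks currently waiting in the queue.
msg = "created task graphrag"
queue_length = 5  # stand-in for get_queue_length(priority)

progress_msg = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
    progress_msg += "\n%d tasks are ahead in the queue..." % queue_length

print(progress_msg)
# created task graphrag
# 5 tasks are ahead in the queue...
```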
@@ -214,6 +214,15 @@ class TenantLLMService(CommonService):
         objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)

+    @staticmethod
+    def llm_id2llm_type(llm_id: str) -> str | None:
+        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
+        llm_factories = settings.FACTORY_LLM_INFOS
+        for llm_factory in llm_factories:
+            for llm in llm_factory["llm"]:
+                if llm_id == llm["llm_name"]:
+                    return llm["model_type"].strip(",")[-1]
+

 class LLMBundle:
     def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
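One subtlety carried over verbatim from the old `rag.prompts` helper (removed in a later hunk): `str.strip(",")` only trims commas from the ends of the string, so the trailing `[-1]` indexes the final character rather than the final comma-separated token. A hedged illustration with a hypothetical `model_type` value:

```python
model_type = "image2text,chat"  # hypothetical model_type value

print(model_type.strip(",")[-1])             # 't'    -- last character
print(model_type.strip(",").split(",")[-1])  # 'chat' -- last comma-separated token
```

Whether the character or the token is intended depends on how `settings.FACTORY_LLM_INFOS` encodes `model_type`, which this diff does not show.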
@@ -26,7 +26,6 @@ import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
-from graphrag import search as kg_search
 from rag.nlp import search

 LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
@@ -169,6 +168,7 @@ def init_settings():
         raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

     retrievaler = search.Dealer(docStoreConn)
+    from graphrag import search as kg_search
     kg_retrievaler = kg_search.KGSearch(docStoreConn)

     if int(os.environ.get("SANDBOX_ENABLED", "0")):
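Moving the graphrag import out of module scope and into `init_settings()` defers it to first use, a standard way to break import-time cycles. A minimal runnable sketch of the pattern, with a stdlib module standing in for `graphrag.search`:

```python
# Deferred-import pattern, minimal illustration.
def init_settings():
    # Imported on first call, not when this module is imported. In the hunk
    # above the same move removes the module-level api.settings -> graphrag
    # edge, while graphrag/utils.py (see a later hunk) still does
    # `from api import settings`; keeping both imports at module scope
    # would risk a circular import.
    import json as kg_search_stand_in  # stand-in for `from graphrag import search as kg_search`
    return kg_search_stand_in


init_settings()
```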
@@ -31,8 +31,6 @@ from urllib.parse import quote, urlencode
 from uuid import uuid1

-import trio

-from api.db.db_models import MCPServer
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


@@ -570,7 +568,7 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
     return transformed_data


-def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
+def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
     results = {}
     tool_call_sessions = []
     try:
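A note on this hunk pair: loosening the annotation from `list[MCPServer]` to a bare `list` is what lets the `from api.db.db_models import MCPServer` import disappear in the previous hunk, so the module no longer pulls in the database models at import time; the visible function body is unchanged.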
@@ -124,7 +124,7 @@ async def run_graphrag(
         return


-@timeout(60*60*2)
+@timeout(60*60, 1)
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,
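The core of the PR is visible here: each knowledge-graph stage gets an explicit budget via the `timeout` decorator from `api.utils.api_utils`. The diff does not show the decorator itself; the sketch below is a minimal asyncio-based stand-in, assuming (from the call sites alone) that the first argument is a limit in seconds and the second is the number of attempts:

```python
# Illustrative stand-in only -- the real decorator is api.utils.api_utils.timeout,
# whose implementation is not part of this diff. Assumed semantics:
# timeout(seconds, attempts) caps each attempt at `seconds` and retries
# up to `attempts` times before giving up.
import asyncio
import functools


def timeout(seconds: float, attempts: int = 1):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exc = TimeoutError(f"gave up after {attempts} attempt(s) of {seconds}s")
            for _ in range(attempts):
                try:
                    # Cancel the awaited call once its time budget is spent.
                    return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
                except asyncio.TimeoutError as e:
                    last_exc = e
            raise last_exc

        return wrapper

    return decorator


@timeout(60 * 60, 1)  # mirrors the hunk above: one hour, single attempt
async def generate_subgraph_stub():
    await asyncio.sleep(0)
```

Under these assumed semantics, `@timeout(60*60, 1)` reads as "one-hour limit, one attempt", replacing the previous two-hour budget.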
@@ -229,7 +229,7 @@ async def merge_subgraph(
     return new_graph


-@timeout(60*60)
+@timeout(60*30, 1)
 async def resolve_entities(
     graph,
     subgraph_nodes: set[str],
@@ -255,7 +255,7 @@ async def resolve_entities(
     callback(msg=f"Graph resolution done in {now - start:.2f}s.")


-@timeout(60*30)
+@timeout(60*30, 1)
 async def extract_community(
     graph,
     tenant_id: str,
@@ -17,13 +17,12 @@ from typing import Any, Callable
 import os
-import trio
-from typing import Set, Tuple

 import networkx as nx
 import numpy as np
 import xxhash
 from networkx.readwrite import json_graph
 import dataclasses

+from api.utils.api_utils import timeout
 from api import settings
 from api.utils import get_uuid
 from rag.nlp import search, rag_tokenizer
@@ -305,6 +304,7 @@ def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


+@timeout(1, 3)
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
@@ -357,6 +357,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     return res


+@timeout(1, 3)
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
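Note the contrast in scale: under the assumed `(seconds, attempts)` semantics sketched earlier, both chunk writers get one second per try with up to three tries, suiting short, retryable per-node and per-edge work, whereas the pipeline stages above get hour-scale budgets with a single attempt.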
@@ -22,7 +22,6 @@ from collections import defaultdict
 import jinja2
 import json_repair

-from api import settings
 from rag.prompt_template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string
@@ -51,18 +50,6 @@ def chunks_format(reference):
     ]


-def llm_id2llm_type(llm_id):
-    from api.db.services.llm_service import TenantLLMService
-
-    llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
-
-    llm_factories = settings.FACTORY_LLM_INFOS
-    for llm_factory in llm_factories:
-        for llm in llm_factory["llm"]:
-            if llm_id == llm["llm_name"]:
-                return llm["model_type"].strip(",")[-1]
-
-
 def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg
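This is the other half of the relocation: the helper leaves `rag.prompts` (along with that module's `from api import settings` import, removed in the previous hunk) and lives on as `TenantLLMService.llm_id2llm_type`, so prompt utilities no longer reach into `settings` for model metadata.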
@@ -188,8 +175,9 @@ def question_proposal(chat_mdl, content, topn=3):
 def full_question(tenant_id, llm_id, messages, language=None):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id2llm_type(llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
@@ -220,8 +208,9 @@ def full_question(tenant_id, llm_id, messages, language=None):
 def cross_languages(tenant_id, llm_id, query, languages=[]):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService

-    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
@@ -506,7 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count


-@timeout(60*60*1.5)
+@timeout(60*60, 1)
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
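Taken together, and reading the decorator's second argument as an attempt count (an assumption; the decorator itself is not shown in this diff), the new budgets are:

| Function | Before | After |
| --- | --- | --- |
| `generate_subgraph` | 2 h | 1 h, 1 attempt |
| `resolve_entities` | 1 h | 30 min, 1 attempt |
| `extract_community` | 30 min | 30 min, 1 attempt |
| `graph_node_to_chunk` | no explicit timeout | 1 s, 3 attempts |
| `graph_edge_to_chunk` | no explicit timeout | 1 s, 3 attempts |
| `do_handle_task` | 1.5 h | 1 h, 1 attempt |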