# Perf: set timeout of some steps in KG. (#8873)
### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
**Dialog service** (`DialogService`, `chat_solo`, `get_models`, `chat`):

```diff
@@ -36,7 +36,7 @@ from api.utils import current_timestamp, datetime_format
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
 
@@ -97,7 +97,7 @@ class DialogService(CommonService):
 
 
 def chat_solo(dialog, messages, stream=True):
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -139,7 +139,7 @@ def get_models(dialog):
     if not embd_mdl:
         raise LookupError("Embedding model(%s) not found" % embedding_list[0])
 
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
@@ -198,7 +198,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 
     chat_start_ts = timer()
 
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
```
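All three call sites switch from the free function `llm_id2llm_type`, formerly imported from `rag.prompts`, to the new `TenantLLMService.llm_id2llm_type` static method; the surrounding logic (pick the vision bundle for `image2text` models, the chat bundle otherwise) is unchanged. A minimal sketch of that shared pattern, using the real `LLMType`, `LLMBundle`, and `TenantLLMService` names from the diff; `pick_chat_model` itself is a hypothetical helper, not part of this PR:

```python
from api.db import LLMType
from api.db.services.llm_service import LLMBundle, TenantLLMService


def pick_chat_model(tenant_id: str, llm_id: str) -> LLMBundle:
    # Vision-capable models need the IMAGE2TEXT bundle; everything else
    # goes through the plain CHAT bundle.
    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        return LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    return LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```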
**Document service** (`DocumentService`):

```diff
@@ -583,6 +583,8 @@ class DocumentService(CommonService):
             info["progress"] = prg
             if msg:
                 info["progress_msg"] = msg
+                if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
+                    info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
             else:
                 info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
             cls.update_by_id(d["id"], info)
```
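The two added lines surface queue pressure at exactly the moment a long-running step is scheduled: when the progress message reports a newly created graphrag or raptor task, the backlog count is appended to the message rather than replacing it (the `else` branch, which replaces it, already existed). A self-contained sketch of the logic with `get_queue_length` stubbed out; in ragflow the real helper reports the task-queue backlog for the given priority:

```python
def get_queue_length(priority: int) -> int:
    return 4  # stub for illustration; the real helper inspects the task queue


msg = "created task graphrag"
info = {"progress_msg": msg}
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
    info["progress_msg"] += "\n%d tasks are ahead in the queue..." % get_queue_length(0)

print(info["progress_msg"])
# created task graphrag
# 4 tasks are ahead in the queue...
```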
**LLM service** (`TenantLLMService`, `LLMBundle`):

```diff
@@ -214,6 +214,15 @@ class TenantLLMService(CommonService):
         objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)
 
+    @staticmethod
+    def llm_id2llm_type(llm_id: str) ->str|None:
+        llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
+        llm_factories = settings.FACTORY_LLM_INFOS
+        for llm_factory in llm_factories:
+            for llm in llm_factory["llm"]:
+                if llm_id == llm["llm_name"]:
+                    return llm["model_type"].strip(",")[-1]
+
 
 class LLMBundle:
     def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
```
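This is the destination of the move: the helper becomes a `@staticmethod` on `TenantLLMService`, right next to the `split_model_name_and_factory` it depends on. It strips any factory suffix from the model id, scans `settings.FACTORY_LLM_INFOS` for a matching model name, and implicitly returns `None` when nothing matches, which is what the `str | None` annotation advertises. A hypothetical usage sketch (the concrete model id is illustrative only):

```python
from api.db.services.llm_service import TenantLLMService

# Model ids may carry a factory suffix, which the helper strips internally.
llm_type = TenantLLMService.llm_id2llm_type("gpt-4o@OpenAI")
if llm_type == "image2text":
    ...  # route the request to a vision-capable bundle
```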
**Settings** (`init_settings`):

```diff
@@ -26,7 +26,6 @@ import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
-from graphrag import search as kg_search
 from rag.nlp import search
 
 LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
@@ -169,6 +168,7 @@ def init_settings():
         raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
 
     retrievaler = search.Dealer(docStoreConn)
+    from graphrag import search as kg_search
     kg_retrievaler = kg_search.KGSearch(docStoreConn)
 
     if int(os.environ.get("SANDBOX_ENABLED", "0")):
```
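Moving the `graphrag` import from module scope into `init_settings()` is a deferred import: simply importing `api.settings` no longer drags in the graphrag package and whatever it transitively loads; the cost is paid once, when settings are actually initialized. A sketch of the pattern under that assumption (the real `init_settings` takes no arguments; `docStoreConn` is a parameter here only to keep the sketch self-contained):

```python
kg_retrievaler = None


def init_settings(docStoreConn):
    global kg_retrievaler
    # Deferred import: keeps importing this module cheap and avoids
    # participating in an import cycle with the graphrag package.
    from graphrag import search as kg_search
    kg_retrievaler = kg_search.KGSearch(docStoreConn)
```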
**API utils** (`get_mcp_tools`):

```diff
@@ -31,8 +31,6 @@ from urllib.parse import quote, urlencode
 from uuid import uuid1
 
 import trio
 
-from api.db.db_models import MCPServer
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
 
-
@@ -570,7 +568,7 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
     return transformed_data
 
 
-def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
+def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
     results = {}
     tool_call_sessions = []
     try:
```
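Dropping the `MCPServer` import removes this module's runtime dependency on `api.db.db_models`, at the price of weakening the annotation to a bare `list`. If the precise type is worth keeping, a typing-only import preserves it for static checkers without reintroducing the dependency; a sketch, assuming `MCPServer` is needed only in annotations here:

```python
from __future__ import annotations  # annotations stay unevaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers, never imported at runtime.
    from api.db.db_models import MCPServer


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    ...
```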
**GraphRAG pipeline** (`generate_subgraph`, `resolve_entities`, `extract_community`):

```diff
@@ -124,7 +124,7 @@ async def run_graphrag(
         return
 
 
-@timeout(60*60*2)
+@timeout(60*60, 1)
 async def generate_subgraph(
     extractor: Extractor,
     tenant_id: str,
@@ -229,7 +229,7 @@ async def merge_subgraph(
     return new_graph
 
 
-@timeout(60*60)
+@timeout(60*30, 1)
 async def resolve_entities(
     graph,
     subgraph_nodes: set[str],
@@ -255,7 +255,7 @@ async def resolve_entities(
     callback(msg=f"Graph resolution done in {now - start:.2f}s.")
 
 
-@timeout(60*30)
+@timeout(60*30, 1)
 async def extract_community(
     graph,
     tenant_id: str,
```
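This is the heart of the PR: each knowledge-graph stage gets a tighter wall-clock budget, and the `timeout` decorator (from `api.utils.api_utils`) gains a second positional argument. From the call shape alone it reads as `timeout(seconds, attempts)`: `generate_subgraph` drops from a 2-hour to a 1-hour window, `resolve_entities` from 60 to 30 minutes, and `extract_community` keeps 30 minutes with the attempt count now explicit. That reading is an assumption, not confirmed by the diff; a trio-based sketch that is compatible with these call sites:

```python
import functools

import trio


def timeout(seconds: float, attempts: int = 1):
    """Sketch only: bound each call of an async function to `seconds`,
    retrying up to `attempts` times. The real decorator lives in
    api.utils.api_utils and may behave differently."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    with trio.fail_after(seconds):  # raises TooSlowError on expiry
                        return await func(*args, **kwargs)
                except trio.TooSlowError:
                    if attempt == attempts - 1:
                        raise TimeoutError(
                            f"{func.__name__} exceeded {seconds}s on attempt {attempt + 1}"
                        )
        return wrapper
    return decorator
```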
**GraphRAG utils** (`graph_node_to_chunk`, `graph_edge_to_chunk`):

```diff
@@ -17,13 +17,12 @@ from typing import Any, Callable
 import os
 import trio
 from typing import Set, Tuple
 
 import networkx as nx
 import numpy as np
 import xxhash
 from networkx.readwrite import json_graph
 import dataclasses
+from api.utils.api_utils import timeout
 from api import settings
 from api.utils import get_uuid
 from rag.nlp import search, rag_tokenizer
@@ -305,6 +304,7 @@ def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
 
 
+@timeout(1, 3)
 async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
@@ -357,6 +357,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     return res
 
 
+@timeout(1, 3)
 async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
```
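Here the budget is much tighter: under the same assumed `(seconds, attempts)` reading, converting a single graph node or edge into an index chunk gets up to three 1-second attempts, so one hung conversion is cut off quickly and retried instead of stalling the whole KG build. The new `from api.utils.api_utils import timeout` at the top of the file is what makes the decorator available to these helpers.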
**Prompts** (`llm_id2llm_type` removal, `full_question`, `cross_languages`):

```diff
@@ -22,7 +22,6 @@ from collections import defaultdict
 import jinja2
 import json_repair
 
-from api import settings
 from rag.prompt_template import load_prompt
 from rag.settings import TAG_FLD
 from rag.utils import encoder, num_tokens_from_string
@@ -51,18 +50,6 @@ def chunks_format(reference):
     ]
 
 
-def llm_id2llm_type(llm_id):
-    from api.db.services.llm_service import TenantLLMService
-
-    llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
-
-    llm_factories = settings.FACTORY_LLM_INFOS
-    for llm_factory in llm_factories:
-        for llm in llm_factory["llm"]:
-            if llm_id == llm["llm_name"]:
-                return llm["model_type"].strip(",")[-1]
-
-
 def message_fit_in(msg, max_length=4000):
     def count():
         nonlocal msg
@@ -188,8 +175,9 @@ def question_proposal(chat_mdl, content, topn=3):
 def full_question(tenant_id, llm_id, messages, language=None):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService
 
-    if llm_id2llm_type(llm_id) == "image2text":
+    if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
@@ -220,8 +208,9 @@ def full_question(tenant_id, llm_id, messages, language=None):
 def cross_languages(tenant_id, llm_id, query, languages=[]):
     from api.db import LLMType
     from api.db.services.llm_service import LLMBundle
+    from api.db.services.llm_service import TenantLLMService
 
-    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
         chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
     else:
         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
```
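Note the shape of the removed helper: it already had to import `TenantLLMService` lazily inside its own body, a workaround for the circular dependency between `rag.prompts` and `api.db.services.llm_service`. Hoisting the function onto `TenantLLMService` dissolves that workaround, and the module-level `from api import settings` in `rag.prompts` becomes unused, hence its removal. The two remaining call sites, `full_question` and `cross_languages`, keep their function-local imports for the same cycle-avoidance reason.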
**Task executor** (`do_handle_task`):

```diff
@@ -506,7 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count
 
 
-@timeout(60*60*1.5)
+@timeout(60*60, 1)
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
```
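The same tightening caps the whole task pipeline: `do_handle_task`'s budget drops from `60*60*1.5` = 5400 s (90 minutes) to `60*60` = 3600 s (60 minutes), again with the second argument, presumably the attempt count, pinned to 1.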