mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-31 15:45:08 +08:00
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change - [x] Refactoring
This commit is contained in:
@ -32,6 +32,8 @@ from graphrag.utils import perform_variable_replacements, chat_limiter, GraphCha
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
|
||||
@ -211,7 +213,7 @@ class EntityResolution(Extractor):
|
||||
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
thread_pool_exec(
|
||||
self._chat,
|
||||
text,
|
||||
[{"role": "user", "content": "Output:"}],
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
@ -26,7 +29,6 @@ from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommunityReportsResult:
|
||||
"""Community reports result class definition."""
|
||||
@ -102,7 +104,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
async with chat_limiter:
|
||||
try:
|
||||
timeout = 180 if enable_timeout_assertion else 1000000000
|
||||
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
|
||||
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
return
|
||||
|
||||
@ -38,6 +38,7 @@ from graphrag.utils import (
|
||||
set_llm_cache,
|
||||
split_string_by_multi_markers,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.prompts.generator import message_fit_in
|
||||
from common.exceptions import TaskCanceledException
|
||||
@ -339,5 +340,5 @@ class Extractor:
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
||||
|
||||
async with chat_limiter:
|
||||
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
return summary
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
@ -107,7 +109,7 @@ class GraphExtractor(Extractor):
|
||||
}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
|
||||
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
@ -117,7 +119,7 @@ class GraphExtractor(Extractor):
|
||||
for i in range(self._max_gleanings):
|
||||
history.append({"role": "user", "content": CONTINUE_PROMPT})
|
||||
async with chat_limiter:
|
||||
response = await asyncio.to_thread(self._chat, "", history, {})
|
||||
response = await thread_pool_exec(self._chat, "", history, {})
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
results += response or ""
|
||||
|
||||
@ -127,7 +129,7 @@ class GraphExtractor(Extractor):
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
async with chat_limiter:
|
||||
continuation = await asyncio.to_thread(self._chat, "", history)
|
||||
continuation = await thread_pool_exec(self._chat, "", history)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
if continuation != "Y":
|
||||
break
|
||||
|
||||
@ -39,6 +39,7 @@ from graphrag.utils import (
|
||||
set_graph,
|
||||
tidy_graph,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
from common import settings
|
||||
@ -460,8 +461,8 @@ async def generate_subgraph(
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
cid = chunk_id(chunk)
|
||||
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
|
||||
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
|
||||
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
|
||||
await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||
return subgraph
|
||||
@ -592,10 +593,10 @@ async def extract_community(
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||
chunks.append(chunk)
|
||||
|
||||
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
|
||||
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
|
||||
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
|
||||
@ -29,6 +29,7 @@ import markdown_to_json
|
||||
from functools import reduce
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
@dataclass
|
||||
class MindMapResult:
|
||||
@ -185,7 +186,7 @@ class MindMapExtractor(Extractor):
|
||||
}
|
||||
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
|
||||
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
|
||||
response = re.sub(r"```[^\n]*", "", response)
|
||||
logging.debug(response)
|
||||
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
@ -19,7 +21,6 @@ from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphExtractionResult:
|
||||
"""Unipartite graph extraction result class definition."""
|
||||
@ -82,12 +83,12 @@ class GraphExtractor(Extractor):
|
||||
if self.callback:
|
||||
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
|
||||
async with chat_limiter:
|
||||
final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
|
||||
final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
|
||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
|
||||
for now_glean_index in range(self._max_gleanings):
|
||||
async with chat_limiter:
|
||||
glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
|
||||
glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
|
||||
history.extend([{"role": "assistant", "content": glean_result}])
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||
final_result += glean_result
|
||||
@ -96,7 +97,7 @@ class GraphExtractor(Extractor):
|
||||
|
||||
history.extend([{"role": "user", "content": self._if_loop_prompt}])
|
||||
async with chat_limiter:
|
||||
if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
|
||||
if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
@ -316,7 +319,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
|
||||
async with chat_limiter:
|
||||
timeout = 3 if enable_timeout_assertion else 30000000
|
||||
ebd, _ = await asyncio.wait_for(
|
||||
asyncio.to_thread(embd_mdl.encode, [ent_name]),
|
||||
thread_pool_exec(embd_mdl.encode, [ent_name]),
|
||||
timeout=timeout
|
||||
)
|
||||
ebd = ebd[0]
|
||||
@ -370,7 +373,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
|
||||
async with chat_limiter:
|
||||
timeout = 3 if enable_timeout_assertion else 300000000
|
||||
ebd, _ = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
thread_pool_exec(
|
||||
embd_mdl.encode,
|
||||
[txt + f": {meta['description']}"]
|
||||
),
|
||||
@ -390,7 +393,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
"knowledge_graph_kwd": ["graph"],
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
res = await asyncio.to_thread(
|
||||
res = await thread_pool_exec(
|
||||
settings.docStoreConn.search,
|
||||
fields, [], condition, [], OrderByExpr(),
|
||||
0, 1, search.index_name(tenant_id), [kb_id]
|
||||
@ -436,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
global chat_limiter
|
||||
start = asyncio.get_running_loop().time()
|
||||
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.delete,
|
||||
{"knowledge_graph_kwd": ["graph", "subgraph"]},
|
||||
search.index_name(tenant_id),
|
||||
@ -444,7 +447,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
)
|
||||
|
||||
if change.removed_nodes:
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.delete,
|
||||
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
|
||||
search.index_name(tenant_id),
|
||||
@ -455,7 +458,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
|
||||
async def del_edges(from_node, to_node):
|
||||
async with chat_limiter:
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.delete,
|
||||
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
|
||||
search.index_name(tenant_id),
|
||||
@ -556,7 +559,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
timeout = 3 if enable_timeout_assertion else 30000000
|
||||
doc_store_result = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
thread_pool_exec(
|
||||
settings.docStoreConn.insert,
|
||||
chunks[b : b + es_bulk_size],
|
||||
search.index_name(tenant_id),
|
||||
@ -650,7 +653,7 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
|
||||
bs = 256
|
||||
for i in range(0, 1024 * bs, bs):
|
||||
es_res = await asyncio.to_thread(
|
||||
es_res = await thread_pool_exec(
|
||||
settings.docStoreConn.search,
|
||||
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
|
||||
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
|
||||
|
||||
Reference in New Issue
Block a user