Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)

### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu
2026-01-20 13:29:37 +08:00
committed by GitHub
parent 120648ac81
commit 927db0b373
30 changed files with 246 additions and 157 deletions

View File

@ -32,6 +32,8 @@ from graphrag.utils import perform_variable_replacements, chat_limiter, GraphCha
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.misc_utils import thread_pool_exec
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
@ -211,7 +213,7 @@ class EntityResolution(Extractor):
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
try:
response = await asyncio.wait_for(
asyncio.to_thread(
thread_pool_exec(
self._chat,
text,
[{"role": "user", "content": "Output:"}],

View File

@ -1,5 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@ -26,7 +29,6 @@ from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from common.token_utils import num_tokens_from_string
@dataclass
class CommunityReportsResult:
"""Community reports result class definition."""
@ -102,7 +104,7 @@ class CommunityReportsExtractor(Extractor):
async with chat_limiter:
try:
timeout = 180 if enable_timeout_assertion else 1000000000
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
except asyncio.TimeoutError:
logging.warning("extract_community_report._chat timeout, skipping...")
return

View File

@ -38,6 +38,7 @@ from graphrag.utils import (
set_llm_cache,
split_string_by_multi_markers,
)
from common.misc_utils import thread_pool_exec
from rag.llm.chat_model import Base as CompletionLLM
from rag.prompts.generator import message_fit_in
from common.exceptions import TaskCanceledException
@ -339,5 +340,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary

View File

@ -1,11 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import re
from typing import Any
from dataclasses import dataclass
@ -107,7 +109,7 @@ class GraphExtractor(Extractor):
}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
@ -117,7 +119,7 @@ class GraphExtractor(Extractor):
for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter:
response = await asyncio.to_thread(self._chat, "", history, {})
response = await thread_pool_exec(self._chat, "", history, {})
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
@ -127,7 +129,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await asyncio.to_thread(self._chat, "", history)
continuation = await thread_pool_exec(self._chat, "", history)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y":
break

View File

@ -39,6 +39,7 @@ from graphrag.utils import (
set_graph,
tidy_graph,
)
from common.misc_utils import thread_pool_exec
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
from common import settings
@ -460,8 +461,8 @@ async def generate_subgraph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
now = asyncio.get_running_loop().time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@ -592,10 +593,10 @@ async def extract_community(
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)

View File

@ -29,6 +29,7 @@ import markdown_to_json
from functools import reduce
from common.token_utils import num_tokens_from_string
from common.misc_utils import thread_pool_exec
@dataclass
class MindMapResult:
@ -185,7 +186,7 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter:
response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))

View File

@ -1,11 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import logging
import re
from dataclasses import dataclass
@ -19,7 +21,6 @@ from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split
from rag.llm.chat_model import Base as CompletionLLM
from common.token_utils import num_tokens_from_string
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
@ -82,12 +83,12 @@ class GraphExtractor(Extractor):
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
async with chat_limiter:
final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
history.extend([{"role": "assistant", "content": glean_result}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
@ -96,7 +97,7 @@ class GraphExtractor(Extractor):
history.extend([{"role": "user", "content": self._if_loop_prompt}])
async with chat_limiter:
if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":

View File

@ -1,5 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@ -316,7 +319,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
async with chat_limiter:
timeout = 3 if enable_timeout_assertion else 30000000
ebd, _ = await asyncio.wait_for(
asyncio.to_thread(embd_mdl.encode, [ent_name]),
thread_pool_exec(embd_mdl.encode, [ent_name]),
timeout=timeout
)
ebd = ebd[0]
@ -370,7 +373,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
async with chat_limiter:
timeout = 3 if enable_timeout_assertion else 300000000
ebd, _ = await asyncio.wait_for(
asyncio.to_thread(
thread_pool_exec(
embd_mdl.encode,
[txt + f": {meta['description']}"]
),
@ -390,7 +393,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await asyncio.to_thread(
res = await thread_pool_exec(
settings.docStoreConn.search,
fields, [], condition, [], OrderByExpr(),
0, 1, search.index_name(tenant_id), [kb_id]
@ -436,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
global chat_limiter
start = asyncio.get_running_loop().time()
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["graph", "subgraph"]},
search.index_name(tenant_id),
@ -444,7 +447,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
)
if change.removed_nodes:
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
search.index_name(tenant_id),
@ -455,7 +458,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
async def del_edges(from_node, to_node):
async with chat_limiter:
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
search.index_name(tenant_id),
@ -556,7 +559,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
for b in range(0, len(chunks), es_bulk_size):
timeout = 3 if enable_timeout_assertion else 30000000
doc_store_result = await asyncio.wait_for(
asyncio.to_thread(
thread_pool_exec(
settings.docStoreConn.insert,
chunks[b : b + es_bulk_size],
search.index_name(tenant_id),
@ -650,7 +653,7 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await asyncio.to_thread(
es_res = await thread_pool_exec(
settings.docStoreConn.search,
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]