Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)

### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu
2026-01-20 13:29:37 +08:00
committed by GitHub
parent 120648ac81
commit 927db0b373
30 changed files with 246 additions and 157 deletions

View File

@ -1,5 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@ -26,7 +29,6 @@ from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from common.token_utils import num_tokens_from_string
@dataclass
class CommunityReportsResult:
"""Community reports result class definition."""
@ -102,7 +104,7 @@ class CommunityReportsExtractor(Extractor):
async with chat_limiter:
try:
timeout = 180 if enable_timeout_assertion else 1000000000
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
except asyncio.TimeoutError:
logging.warning("extract_community_report._chat timeout, skipping...")
return

View File

@ -38,6 +38,7 @@ from graphrag.utils import (
set_llm_cache,
split_string_by_multi_markers,
)
from common.misc_utils import thread_pool_exec
from rag.llm.chat_model import Base as CompletionLLM
from rag.prompts.generator import message_fit_in
from common.exceptions import TaskCanceledException
@ -339,5 +340,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary

View File

@ -1,11 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import re
from typing import Any
from dataclasses import dataclass
@ -107,7 +109,7 @@ class GraphExtractor(Extractor):
}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
@ -117,7 +119,7 @@ class GraphExtractor(Extractor):
for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter:
response = await asyncio.to_thread(self._chat, "", history, {})
response = await thread_pool_exec(self._chat, "", history, {})
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
@ -127,7 +129,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await asyncio.to_thread(self._chat, "", history)
continuation = await thread_pool_exec(self._chat, "", history)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y":
break

View File

@ -39,6 +39,7 @@ from graphrag.utils import (
set_graph,
tidy_graph,
)
from common.misc_utils import thread_pool_exec
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
from common import settings
@ -460,8 +461,8 @@ async def generate_subgraph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
now = asyncio.get_running_loop().time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@ -592,10 +593,10 @@ async def extract_community(
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)

View File

@ -29,6 +29,7 @@ import markdown_to_json
from functools import reduce
from common.token_utils import num_tokens_from_string
from common.misc_utils import thread_pool_exec
@dataclass
class MindMapResult:
@ -185,7 +186,7 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter:
response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))