mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-31 15:45:08 +08:00
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change - [x] Refactoring
This commit is contained in:
@ -1,5 +1,8 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
@ -26,7 +29,6 @@ from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommunityReportsResult:
|
||||
"""Community reports result class definition."""
|
||||
@ -102,7 +104,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
async with chat_limiter:
|
||||
try:
|
||||
timeout = 180 if enable_timeout_assertion else 1000000000
|
||||
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
|
||||
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
return
|
||||
|
||||
@ -38,6 +38,7 @@ from graphrag.utils import (
|
||||
set_llm_cache,
|
||||
split_string_by_multi_markers,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.prompts.generator import message_fit_in
|
||||
from common.exceptions import TaskCanceledException
|
||||
@ -339,5 +340,5 @@ class Extractor:
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
||||
|
||||
async with chat_limiter:
|
||||
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
return summary
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
@ -107,7 +109,7 @@ class GraphExtractor(Extractor):
|
||||
}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
|
||||
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
@ -117,7 +119,7 @@ class GraphExtractor(Extractor):
|
||||
for i in range(self._max_gleanings):
|
||||
history.append({"role": "user", "content": CONTINUE_PROMPT})
|
||||
async with chat_limiter:
|
||||
response = await asyncio.to_thread(self._chat, "", history, {})
|
||||
response = await thread_pool_exec(self._chat, "", history, {})
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
results += response or ""
|
||||
|
||||
@ -127,7 +129,7 @@ class GraphExtractor(Extractor):
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
async with chat_limiter:
|
||||
continuation = await asyncio.to_thread(self._chat, "", history)
|
||||
continuation = await thread_pool_exec(self._chat, "", history)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
if continuation != "Y":
|
||||
break
|
||||
|
||||
@ -39,6 +39,7 @@ from graphrag.utils import (
|
||||
set_graph,
|
||||
tidy_graph,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
from common import settings
|
||||
@ -460,8 +461,8 @@ async def generate_subgraph(
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
cid = chunk_id(chunk)
|
||||
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
|
||||
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
|
||||
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
|
||||
await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||
return subgraph
|
||||
@ -592,10 +593,10 @@ async def extract_community(
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||
chunks.append(chunk)
|
||||
|
||||
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
|
||||
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
|
||||
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
|
||||
@ -29,6 +29,7 @@ import markdown_to_json
|
||||
from functools import reduce
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
@dataclass
|
||||
class MindMapResult:
|
||||
@ -185,7 +186,7 @@ class MindMapExtractor(Extractor):
|
||||
}
|
||||
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
|
||||
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
|
||||
response = re.sub(r"```[^\n]*", "", response)
|
||||
logging.debug(response)
|
||||
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
||||
|
||||
Reference in New Issue
Block a user