mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Made task_executor async to speed up parsing (#5530)
### What problem does this PR solve?

Made task_executor async to speed up parsing

### Type of change

- [x] Performance Improvement
This commit is contained in:
@ -16,16 +16,14 @@
|
||||
|
||||
import logging
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor
|
||||
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import markdown_to_json
|
||||
from functools import reduce
|
||||
@ -80,63 +78,47 @@ class MindMapExtractor(Extractor):
|
||||
)
|
||||
return arr
|
||||
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self, sections: list[str], prompt_variables: dict[str, Any] | None = None
|
||||
) -> MindMapResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
|
||||
try:
|
||||
res = []
|
||||
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
||||
threads = []
|
||||
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
||||
texts = []
|
||||
cnt = 0
|
||||
for i in range(len(sections)):
|
||||
section_cnt = num_tokens_from_string(sections[i])
|
||||
if cnt + section_cnt >= token_count and texts:
|
||||
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
||||
texts = []
|
||||
cnt = 0
|
||||
texts.append(sections[i])
|
||||
cnt += section_cnt
|
||||
if texts:
|
||||
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
||||
|
||||
for i, _ in enumerate(threads):
|
||||
res.append(_.result())
|
||||
|
||||
if not res:
|
||||
return MindMapResult(output={"id": "root", "children": []})
|
||||
|
||||
merge_json = reduce(self._merge, res)
|
||||
if len(merge_json) > 1:
|
||||
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
|
||||
keyset = set(i for i in keys if i)
|
||||
merge_json = {
|
||||
"id": "root",
|
||||
"children": [
|
||||
{
|
||||
"id": self._key(k),
|
||||
"children": self._be_children(v, keyset)
|
||||
}
|
||||
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
|
||||
]
|
||||
}
|
||||
else:
|
||||
k = self._key(list(merge_json.keys())[0])
|
||||
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("error mind graph")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(), None
|
||||
)
|
||||
merge_json = {"error": str(e)}
|
||||
res = []
|
||||
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
||||
texts = []
|
||||
cnt = 0
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(len(sections)):
|
||||
section_cnt = num_tokens_from_string(sections[i])
|
||||
if cnt + section_cnt >= token_count and texts:
|
||||
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
||||
texts = []
|
||||
cnt = 0
|
||||
texts.append(sections[i])
|
||||
cnt += section_cnt
|
||||
if texts:
|
||||
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
||||
if not res:
|
||||
return MindMapResult(output={"id": "root", "children": []})
|
||||
merge_json = reduce(self._merge, res)
|
||||
if len(merge_json) > 1:
|
||||
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
|
||||
keyset = set(i for i in keys if i)
|
||||
merge_json = {
|
||||
"id": "root",
|
||||
"children": [
|
||||
{
|
||||
"id": self._key(k),
|
||||
"children": self._be_children(v, keyset)
|
||||
}
|
||||
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
|
||||
]
|
||||
}
|
||||
else:
|
||||
k = self._key(list(merge_json.keys())[0])
|
||||
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
|
||||
|
||||
return MindMapResult(output=merge_json)
|
||||
|
||||
@ -181,8 +163,8 @@ class MindMapExtractor(Extractor):
|
||||
|
||||
return self._list_to_kv(to_ret)
|
||||
|
||||
def _process_document(
|
||||
self, text: str, prompt_variables: dict[str, str]
|
||||
async def _process_document(
|
||||
self, text: str, prompt_variables: dict[str, str], out_res
|
||||
) -> str:
|
||||
variables = {
|
||||
**prompt_variables,
|
||||
@ -190,8 +172,9 @@ class MindMapExtractor(Extractor):
|
||||
}
|
||||
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
response = re.sub(r"```[^\n]*", "", response)
|
||||
logging.debug(response)
|
||||
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
||||
return self._todict(markdown_to_json.dictify(response))
|
||||
out_res.append(self._todict(markdown_to_json.dictify(response)))
|
||||
|
||||
Reference in New Issue
Block a user