Made task_executor async to speedup parsing (#5530)

### What problem does this PR solve?

Made task_executor async to speedup parsing

### Type of change

- [x] Performance Improvement
This commit is contained in:
Zhichang Yu
2025-03-03 18:59:49 +08:00
committed by GitHub
parent abac2ca2c5
commit c813c1ff4c
22 changed files with 576 additions and 1005 deletions

View File

@ -17,9 +17,10 @@ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
import trio
@dataclass
@ -52,7 +53,7 @@ class CommunityReportsExtractor(Extractor):
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
self._max_report_length = max_report_length or 1500
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
@ -86,28 +87,25 @@ class CommunityReportsExtractor(Extractor):
}
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.3}
try:
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response)
logging.debug(response)
response = json.loads(response)
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]):
continue
response["weight"] = weight
response["entities"] = ents
except Exception:
logging.exception("CommunityReportsExtractor got exception")
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response)
logging.debug(response)
response = json.loads(response)
if not dict_has_keys_with_types(response, [
("title", str),
("summary", str),
("findings", list),
("rating", float),
("rating_explanation", str),
]):
continue
response["weight"] = weight
response["entities"] = ents
add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))