Made task_executor async to speedup parsing (#5530)

### What problem does this PR solve?

Made task_executor async to speedup parsing

### Type of change

- [x] Performance Improvement
This commit is contained in:
Zhichang Yu
2025-03-03 18:59:49 +08:00
committed by GitHub
parent abac2ca2c5
commit c813c1ff4c
22 changed files with 576 additions and 1005 deletions

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import itertools
import re
import time
@ -21,13 +20,14 @@ from dataclasses import dataclass
from typing import Any, Callable
import networkx as nx
import trio
from graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements
from graphrag.utils import perform_variable_replacements, chat_limiter
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@ -67,13 +67,13 @@ class EntityResolution(Extractor):
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
# Wire defaults into the prompt variables
prompt_variables = {
self.prompt_variables = {
**prompt_variables,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
@ -94,39 +94,12 @@ class EntityResolution(Extractor):
for k, v in node_clusters.items():
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
gen_conf = {"temperature": 0.5}
resolution_result = set()
for candidate_resolution_i in candidate_resolution.items():
if candidate_resolution_i[1]:
try:
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)
variables = {
**prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
result = self._process_results(len(candidate_resolution_i[1]), response,
prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
prompt_variables.get(self._entity_index_dilimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
except Exception:
logging.exception("error entity resolution")
async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]:
continue
nursery.start_soon(self._resolve_candidate(candidate_resolution_i, resolution_result))
connect_graph = nx.Graph()
removed_entities = []
@ -172,6 +145,34 @@ class EntityResolution(Extractor):
removed_entities=removed_entities
)
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
gen_conf = {"temperature": 0.5}
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
for index, candidate in enumerate(candidate_resolution_i[1]):
pair_txt.append(
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
pair_txt.append(
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
pair_prompt = '\n'.join(pair_txt)
variables = {
**self.prompt_variables,
self._input_text_key: pair_prompt
}
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
result = self._process_results(len(candidate_resolution_i[1]), response,
self.prompt_variables.get(self._record_delimiter_key,
DEFAULT_RECORD_DELIMITER),
self.prompt_variables.get(self._entity_index_dilimiter_key,
DEFAULT_ENTITY_INDEX_DELIMITER),
self.prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
def _process_results(
self,
records_length: int,