refine log info about graphrag progress (#1823)

### What problem does this PR solve?

Refines progress logging during GraphRAG entity extraction: `GraphExtractor.__call__` now accepts an optional callback and reports, after each processed chunk, the current position, elapsed time, and accumulated token usage.
### Type of change

- [x] Refactoring
Commit: 43199c45c3 (parent 3fd7db40ea)
Author: Kevin Hu
Date: 2024-08-06 16:01:43 +08:00
Committed by: GitHub
5 changed files with 32 additions and 15 deletions

```diff
@@ -21,13 +21,14 @@ import numbers
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, Mapping
+from typing import Any, Mapping, Callable
 import tiktoken
 from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer
 
 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
```
```diff
@@ -103,7 +104,9 @@ class GraphExtractor:
         self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
 
     def __call__(
-        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
+        self, texts: list[str],
+        prompt_variables: dict[str, Any] | None = None,
+        callback: Callable | None = None
     ) -> GraphExtractionResult:
         """Call method definition."""
         if prompt_variables is None:
```
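With `callback` added to `__call__`, callers can observe extraction progress without polling. A minimal usage sketch, assuming an already-constructed `GraphExtractor` (constructor arguments elided); `my_llm` and `chunks` are placeholders, not part of this diff:

```python
# Hypothetical usage of the new signature; the extractor invokes the hook
# as callback(msg=...), so the callable must accept a `msg` keyword argument.
def on_progress(msg: str):
    print(f"[graphrag] {msg}")

extractor = GraphExtractor(my_llm)  # my_llm: a CompletionLLM instance (assumed)
result = extractor(chunks, callback=on_progress)
```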
```diff
@@ -127,12 +130,17 @@
             ),
         }
 
+        st = timer()
+        total = len(texts)
+        total_token_count = 0
         for doc_index, text in enumerate(texts):
             try:
                 # Invoke the entity extraction
-                result = self._process_document(text, prompt_variables)
+                result, token_count = self._process_document(text, prompt_variables)
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
+                total_token_count += token_count
+                if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
             except Exception as e:
                 logging.exception("error extracting graph")
                 self._on_error(
```
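The loop now tracks wall-clock time (via `timeit.default_timer`) and cumulative token usage, pushing both through the callback after each chunk. The bookkeeping pattern in isolation, as a runnable sketch (the word count below is a stand-in for the real tokenizer):

```python
from timeit import default_timer as timer

docs = ["first chunk", "second chunk"]  # placeholder inputs
st = timer()
total_token_count = 0
for i, doc in enumerate(docs):
    token_count = len(doc.split())      # stand-in for num_tokens_from_string
    total_token_count += token_count
    print(f"{i + 1}/{len(docs)}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
```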
```diff
@@ -162,9 +170,11 @@
             **prompt_variables,
             self._input_text_key: text,
         }
+        token_count = 0
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
-        gen_conf = {"temperature": 0.5}
+        gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [], gen_conf)
+        token_count = num_tokens_from_string(text + response)
         results = response or ""
 
         history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
```
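Token usage per document is measured with `num_tokens_from_string` from `rag.utils`, applied to the concatenated prompt and response. Its implementation is outside this diff; a typical tiktoken-based equivalent might look like the sketch below (an assumption, not the project's actual helper):

```python
import tiktoken

# Assumed shape of rag.utils.num_tokens_from_string; the real helper may
# cache the encoder or select a model-specific encoding.
def num_tokens_from_string(string: str) -> int:
    encoder = tiktoken.get_encoding("cl100k_base")
    return len(encoder.encode(string))
```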
```diff
@@ -185,7 +195,7 @@
             if continuation != "YES":
                 break
 
-        return results
+        return results, token_count
 
     def _process_results(
         self,
```