mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
refine loginfo about graprag progress (#1823)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
@ -23,16 +23,16 @@ import logging
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
from typing import Any, List, Callable
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from graphrag import leiden
|
||||
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from graphrag.leiden import add_community_info2graph
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
|
||||
from rag.utils import num_tokens_from_string
|
||||
from timeit import default_timer as timer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -67,11 +67,14 @@ class CommunityReportsExtractor:
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
def __call__(self, graph: nx.Graph):
|
||||
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
|
||||
total = sum([len(comm.items()) for _, comm in communities.items()])
|
||||
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
||||
res_str = []
|
||||
res_dict = []
|
||||
over, token_count = 0, 0
|
||||
st = timer()
|
||||
for level, comm in communities.items():
|
||||
for cm_id, ents in comm.items():
|
||||
weight = ents["weight"]
|
||||
@ -84,9 +87,10 @@ class CommunityReportsExtractor:
|
||||
"relation_df": rela_df.to_csv(index_label="id")
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
gen_conf = {"temperature": 0.3}
|
||||
try:
|
||||
response = self._llm.chat(text, [], gen_conf)
|
||||
token_count += num_tokens_from_string(text + response)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
response = re.sub(r"[^\}]*$", "", response)
|
||||
print(response)
|
||||
@ -108,6 +112,8 @@ class CommunityReportsExtractor:
|
||||
add_community_info2graph(graph, ents, response["title"])
|
||||
res_str.append(self._get_text_output(response))
|
||||
res_dict.append(response)
|
||||
over += 1
|
||||
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
|
||||
|
||||
return CommunityReportsResult(
|
||||
structured_output=res_dict,
|
||||
|
||||
Reference in New Issue
Block a user