Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

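Fix errors detected by Ruff, mainly compound statements written on a single line (e.g. `if callback: callback(...)`), which are split so that the body sits on its own line. No behavior change is intended.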

### Type of change

- [x] Refactoring
This commit is contained in:
Zhichang Yu
2024-12-08 14:21:12 +08:00
committed by GitHub
parent e267a026f3
commit 0d68a6cd1b
97 changed files with 2558 additions and 1976 deletions

View File

@ -88,7 +88,8 @@ class CommunityReportsExtractor:
("findings", list),
("rating", float),
("rating_explanation", str),
]): continue
]):
continue
response["weight"] = weight
response["entities"] = ents
except Exception as e:
@ -100,7 +101,8 @@ class CommunityReportsExtractor:
res_str.append(self._get_text_output(response))
res_dict.append(response)
over += 1
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
if callback:
callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
return CommunityReportsResult(
structured_output=res_dict,

View File

@ -8,6 +8,7 @@ Reference:
from typing import Any
import numpy as np
import networkx as nx
from dataclasses import dataclass
from graphrag.leiden import stable_largest_connected_component

View File

@ -129,9 +129,11 @@ class GraphExtractor:
source_doc_map[doc_index] = text
all_records[doc_index] = result
total_token_count += token_count
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
if callback:
callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e:
if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e)))
if callback:
callback(msg="Knowledge graph extraction error:{}".format(str(e)))
logging.exception("error extracting graph")
self._on_error(
e,
@ -164,7 +166,8 @@ class GraphExtractor:
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.3}
response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
if response.find("**ERROR**") >= 0: raise Exception(response)
if response.find("**ERROR**") >= 0:
raise Exception(response)
token_count = num_tokens_from_string(text + response)
results = response or ""
@ -175,7 +178,8 @@ class GraphExtractor:
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._llm.chat("", history, gen_conf)
if response.find("**ERROR**") >=0: raise Exception(response)
if response.find("**ERROR**") >=0:
raise Exception(response)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag

View File

@ -134,7 +134,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
callback(0.75, "Extracting mind graph.")
mindmap = MindMapExtractor(llm_bdl)
mg = mindmap(_chunks).output
if not len(mg.keys()): return chunks
if not len(mg.keys()):
return chunks
logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
chunks.append(

View File

@ -78,7 +78,8 @@ def _compute_leiden_communities(
) -> dict[int, dict[str, int]]:
"""Return Leiden root communities."""
results: dict[int, dict[str, int]] = {}
if is_empty(graph): return results
if is_empty(graph):
return results
if use_lcc:
graph = stable_largest_connected_component(graph)
@ -100,7 +101,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
logging.debug(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
if not graph.nodes(): return {}
if not graph.nodes():
return {}
node_id_to_community_map = _compute_leiden_communities(
graph=graph,
@ -125,9 +127,11 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
result[community_id]["nodes"].append(node_id)
result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
weights = [comm["weight"] for _, comm in result.items()]
if not weights:continue
if not weights:
continue
max_weight = max(weights)
for _, comm in result.items(): comm["weight"] /= max_weight
for _, comm in result.items():
comm["weight"] /= max_weight
return results_by_level