mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-31 23:55:06 +08:00
0
rag/graphrag/general/__init__.py
Normal file
0
rag/graphrag/general/__init__.py
Normal file
158
rag/graphrag/general/community_report_prompt.py
Normal file
158
rag/graphrag/general/community_report_prompt.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/community_report.py)
|
||||
"""
|
||||
|
||||
COMMUNITY_REPORT_PROMPT = """
|
||||
You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
|
||||
|
||||
# Goal
|
||||
Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims.
|
||||
|
||||
# Report Structure
|
||||
|
||||
The report should include the following sections:
|
||||
|
||||
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
|
||||
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
|
||||
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
|
||||
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
|
||||
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
|
||||
|
||||
Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
|
||||
{{
|
||||
"title": <report_title>,
|
||||
"summary": <executive_summary>,
|
||||
"rating": <impact_severity_rating>,
|
||||
"rating_explanation": <rating_explanation>,
|
||||
"findings": [
|
||||
{{
|
||||
"summary":<insight_1_summary>,
|
||||
"explanation": <insight_1_explanation>
|
||||
}},
|
||||
{{
|
||||
"summary":<insight_2_summary>,
|
||||
"explanation": <insight_2_explanation>
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
# Grounding Rules
|
||||
|
||||
Points supported by data should list their data references as follows:
|
||||
|
||||
"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
|
||||
|
||||
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
|
||||
|
||||
For example:
|
||||
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
|
||||
|
||||
where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
|
||||
|
||||
Do not include information where the supporting evidence for it is not provided.
|
||||
|
||||
|
||||
# Example Input
|
||||
-----------
|
||||
Text:
|
||||
|
||||
-Entities-
|
||||
|
||||
id,entity,description
|
||||
5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March
|
||||
6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza
|
||||
|
||||
-Relationships-
|
||||
|
||||
id,source,target,description
|
||||
37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
|
||||
38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
|
||||
39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
|
||||
40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
|
||||
41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
|
||||
43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March
|
||||
|
||||
Output:
|
||||
{{
|
||||
"title": "Verdant Oasis Plaza and Unity March",
|
||||
"summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.",
|
||||
"rating": 5.0,
|
||||
"rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.",
|
||||
"findings": [
|
||||
{{
|
||||
"summary": "Verdant Oasis Plaza as the central location",
|
||||
"explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]"
|
||||
}},
|
||||
{{
|
||||
"summary": "Harmony Assembly's role in the community",
|
||||
"explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]"
|
||||
}},
|
||||
{{
|
||||
"summary": "Unity March as a significant event",
|
||||
"explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]"
|
||||
}},
|
||||
{{
|
||||
"summary": "Role of Tribune Spotlight",
|
||||
"explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
|
||||
# Real Data
|
||||
|
||||
Use the following text for your answer. Do not make anything up in your answer.
|
||||
|
||||
Text:
|
||||
|
||||
-Entities-
|
||||
{entity_df}
|
||||
|
||||
-Relationships-
|
||||
{relation_df}
|
||||
|
||||
The report should include the following sections:
|
||||
|
||||
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
|
||||
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
|
||||
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
|
||||
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
|
||||
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
|
||||
|
||||
Return output as a well-formed JSON-formatted string with the following format(in language of 'Text' content):
|
||||
{{
|
||||
"title": <report_title>,
|
||||
"summary": <executive_summary>,
|
||||
"rating": <impact_severity_rating>,
|
||||
"rating_explanation": <rating_explanation>,
|
||||
"findings": [
|
||||
{{
|
||||
"summary":<insight_1_summary>,
|
||||
"explanation": <insight_1_explanation>
|
||||
}},
|
||||
{{
|
||||
"summary":<insight_2_summary>,
|
||||
"explanation": <insight_2_explanation>
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
# Grounding Rules
|
||||
|
||||
Points supported by data should list their data references as follows:
|
||||
|
||||
"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
|
||||
|
||||
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
|
||||
|
||||
For example:
|
||||
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
|
||||
|
||||
where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
|
||||
|
||||
Do not include information where the supporting evidence for it is not provided.
|
||||
|
||||
Output:"""
|
||||
186
rag/graphrag/general/community_reports_extractor.py
Normal file
186
rag/graphrag/general/community_reports_extractor.py
Normal file
@ -0,0 +1,186 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Callable
|
||||
from dataclasses import dataclass
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.connection_utils import timeout
|
||||
from rag.graphrag.general import leiden
|
||||
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from rag.graphrag.general.extractor import Extractor
|
||||
from rag.graphrag.general.leiden import add_community_info2graph
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
@dataclass
class CommunityReportsResult:
    """Bundle of the reports produced by one ``CommunityReportsExtractor`` call.

    Both lists are filled in lockstep by the extractor: one text rendering and
    one parsed dictionary per accepted report.
    """

    # Text rendering of each accepted report.
    output: list[str]

    # Parsed report dictionaries, one per accepted report.
    structured_output: list[dict]
|
||||
|
||||
|
||||
class CommunityReportsExtractor(Extractor):
    """Write an LLM-generated report for every community detected in a graph.

    Calling the instance runs ``leiden.run`` on the graph, prompts the LLM once
    per community that has at least two entities, and collects every report
    that parses as JSON and passes the field/type validation.
    """

    # Prompt template; ``entity_df`` and ``relation_df`` are substituted per community.
    _extraction_prompt: str
    # Declared but never assigned in this class.
    _output_formatter_prompt: str
    # Stored by ``__init__``; not read anywhere else in this class.
    _max_report_length: int
    # Upper bound on the relationship rows sent to the LLM for one community.
    _MAX_RELATIONS_PER_COMMUNITY = 10000

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        max_report_length: int | None = None,
    ):
        """Store the LLM handle, the report prompt and the report length setting.

        Args:
            llm_invoker: Chat model used to write the reports.
            max_report_length: Stored as ``_max_report_length``; falls back to
                1500 when ``None`` or 0.
        """
        super().__init__(llm_invoker)
        self._llm = llm_invoker
        self._extraction_prompt = COMMUNITY_REPORT_PROMPT
        self._max_report_length = max_report_length or 1500

    async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""):
        """Detect communities in ``graph`` and generate one report per community.

        Side effects on ``graph``: every node gets a ``rank`` attribute equal to
        its degree, and ``add_community_info2graph`` is called with the members
        and title of every accepted report.

        Args:
            graph: Graph whose nodes and edges carry a ``description`` attribute.
            callback: Optional progress hook, invoked as ``callback(msg=...)``.
            task_id: When non-empty, ``has_canceled(task_id)`` is polled and a
                cancellation aborts the run.

        Returns:
            A ``CommunityReportsResult`` holding the accepted reports.

        Raises:
            TaskCanceledException: If the task is cancelled before or during
                report generation.
        """
        enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

        communities: dict[str, dict[str, list]] = leiden.run(graph, {})
        total = sum(len(comm) for comm in communities.values())
        res_str = []
        res_dict = []
        over, token_count = 0, 0

        @timeout(120)
        async def extract_community_report(community):
            nonlocal over, token_count
            if task_id and has_canceled(task_id):
                logging.info(f"Task {task_id} cancelled during community report extraction.")
                raise TaskCanceledException(f"Task {task_id} was cancelled")

            cm_id, cm = community
            weight = cm["weight"]
            ents = cm["nodes"]
            if len(ents) < 2:
                # A single-entity community has no relationships to report on.
                return

            ent_df = pd.DataFrame([{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents])
            rela_df = pd.DataFrame(self._community_relations(graph, ents))
            prompt_variables = {
                "entity_df": ent_df.to_csv(index_label="id"),
                "relation_df": rela_df.to_csv(index_label="id"),
            }
            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)

            async with chat_limiter:
                # Named chat_timeout so the imported ``timeout`` decorator is not shadowed.
                chat_timeout = 180 if enable_timeout_assertion else 1000000000
                try:
                    response = await asyncio.wait_for(
                        thread_pool_exec(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id),
                        timeout=chat_timeout,
                    )
                except asyncio.TimeoutError:
                    logging.warning("extract_community_report._chat timeout, skipping...")
                    return
                except TaskCanceledException:
                    # Cancellation must abort the whole run, not be logged as a failed chat.
                    raise
                except Exception as e:
                    logging.error(f"extract_community_report._chat failed: {e}")
                    return

            token_count += num_tokens_from_string(text + response)
            response = self._parse_report_json(response)
            if response is None:
                return

            # json.loads yields int for a rating written as "5"; the type check below wants float.
            rating = response.get("rating")
            if isinstance(rating, int) and not isinstance(rating, bool):
                response["rating"] = float(rating)

            if not dict_has_keys_with_types(response, [
                ("title", str),
                ("summary", str),
                ("findings", list),
                ("rating", float),
                ("rating_explanation", str),
            ]):
                logging.warning(f"Community report {cm_id} is missing required fields or has wrong types, skipping...")
                return

            response["weight"] = weight
            response["entities"] = ents
            add_community_info2graph(graph, ents, response["title"])
            res_str.append(self._get_text_output(response))
            res_dict.append(response)
            over += 1
            if callback:
                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")

        st = asyncio.get_running_loop().time()
        tasks = []
        for level, comm in communities.items():
            logging.info(f"Level {level}: Community: {len(comm.keys())}")
            for community in comm.items():
                if task_id and has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled before community processing.")
                    raise TaskCanceledException(f"Task {task_id} was cancelled")
                tasks.append(asyncio.create_task(extract_community_report(community)))
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error(f"Error in community processing: {e}")
            # Stop the siblings and wait for them to settle before re-raising.
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
        if callback:
            callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    def _community_relations(self, graph: nx.Graph, ents: list) -> list[dict]:
        """Collect the edges between members of ``ents``, visiting each unordered pair once.

        Stops as soon as ``_MAX_RELATIONS_PER_COMMUNITY`` rows have been collected.
        """
        relations = []
        for i in range(len(ents)):
            for j in range(i + 1, len(ents)):
                edge = graph.get_edge_data(ents[i], ents[j])
                if edge is None:
                    continue
                relations.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
                if len(relations) >= self._MAX_RELATIONS_PER_COMMUNITY:
                    return relations
        return relations

    @staticmethod
    def _parse_report_json(raw: str) -> dict | None:
        """Extract the JSON object from an LLM reply, or return ``None`` if it cannot be parsed.

        Text before the first ``{`` and after the last ``}`` is dropped. The
        remainder is parsed as-is first; only when that fails are doubled braces
        (``{{`` / ``}}``, as used in the prompt template) collapsed and the parse
        retried, so valid compact nested JSON is no longer corrupted.
        """
        candidate = re.sub(r"^[^\{]*", "", raw)
        candidate = re.sub(r"[^\}]*$", "", candidate)
        logging.debug(candidate)
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            collapsed = candidate.replace("{{", "{").replace("}}", "}")
            try:
                parsed = json.loads(collapsed)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse JSON response: {e}")
                logging.error(f"Response content: {collapsed}")
                return None
        return parsed if isinstance(parsed, dict) else None

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report as text: a title line, the summary, then one section per finding.

        A finding may be a plain string (used as the section heading, with an
        empty body) or a dict with ``summary`` and ``explanation`` keys; a
        missing key renders as an empty string.
        """
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding: dict):
            if isinstance(finding, str):
                return finding
            return finding.get("summary") or ""

        def finding_explanation(finding: dict):
            if isinstance(finding, str):
                return ""
            return finding.get("explanation") or ""

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"
|
||||
66
rag/graphrag/general/entity_embedding.py
Normal file
66
rag/graphrag/general/entity_embedding.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from dataclasses import dataclass
|
||||
from rag.graphrag.general.leiden import stable_largest_connected_component
|
||||
import graspologic as gc
|
||||
|
||||
|
||||
@dataclass
class NodeEmbeddings:
    """Pairing of node labels with their embedding array.

    ``embed_node2vec`` fills both fields from the two elements returned by
    ``gc.embed.node2vec_embed``.
    """

    # Node labels.
    nodes: list[str]

    # Embedding array; ``run`` pairs ``embeddings.tolist()`` with ``nodes`` one-to-one.
    embeddings: np.ndarray
|
||||
|
||||
|
||||
def embed_node2vec(
    graph: nx.Graph | nx.DiGraph,
    dimensions: int = 1536,
    num_walks: int = 10,
    walk_length: int = 40,
    window_size: int = 2,
    iterations: int = 3,
    random_seed: int = 86,
) -> NodeEmbeddings:
    """Embed the nodes of ``graph`` with graspologic's ``node2vec_embed``.

    Every keyword argument is forwarded unchanged to
    ``gc.embed.node2vec_embed``.

    Returns:
        A ``NodeEmbeddings`` whose ``embeddings`` is the first element of the
        graspologic result and whose ``nodes`` is the second.
    """
    node2vec_options = {
        "dimensions": dimensions,
        "window_size": window_size,
        "iterations": iterations,
        "num_walks": num_walks,
        "walk_length": walk_length,
        "random_seed": random_seed,
    }
    result = gc.embed.node2vec_embed(graph=graph, **node2vec_options)  # type: ignore
    vectors, labels = result[0], result[1]
    return NodeEmbeddings(nodes=labels, embeddings=vectors)
|
||||
|
||||
|
||||
def run(graph: nx.Graph, args: dict[str, Any]) -> dict:
    """Compute node2vec embeddings and return them keyed by node, in sorted node order.

    Args:
        graph: Graph to embed.
        args: Optional settings. ``use_lcc`` (default ``True``) first passes the
            graph through ``stable_largest_connected_component``; ``dimensions``,
            ``num_walks``, ``walk_length``, ``window_size``, ``iterations`` and
            ``random_seed`` are forwarded to ``embed_node2vec``.

    Returns:
        Mapping of node label to its embedding as a plain list.
    """
    use_lcc = args.get("use_lcc", True)
    if use_lcc:
        graph = stable_largest_connected_component(graph)

    # (option name, fallback when the caller did not supply it)
    option_defaults = (
        ("dimensions", 1536),
        ("num_walks", 10),
        ("walk_length", 40),
        ("window_size", 2),
        ("iterations", 3),
        ("random_seed", 86),
    )
    options = {name: args.get(name, fallback) for name, fallback in option_defaults}
    result = embed_node2vec(graph=graph, **options)

    # strict=True makes a label/vector count mismatch raise instead of truncating.
    ordered = sorted(
        zip(result.nodes, result.embeddings.tolist(), strict=True),
        key=lambda pair: pair[0],
    )
    return {node: vector for node, vector in ordered}
|
||||
344
rag/graphrag/general/extractor.py
Normal file
344
rag/graphrag/general/extractor.py
Normal file
@ -0,0 +1,344 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import Counter, defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.connection_utils import timeout
|
||||
from common.token_utils import truncate
|
||||
from rag.graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
from rag.graphrag.utils import (
|
||||
GraphChange,
|
||||
chat_limiter,
|
||||
flat_uniq_list,
|
||||
get_from_to,
|
||||
get_llm_cache,
|
||||
handle_single_entity_extraction,
|
||||
handle_single_relationship_extraction,
|
||||
set_llm_cache,
|
||||
split_string_by_multi_markers,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.prompts.generator import message_fit_in
|
||||
from common.exceptions import TaskCanceledException
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
|
||||
ENTITY_EXTRACTION_MAX_GLEANINGS = 2
|
||||
MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = int(os.environ.get("MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK", 10))
|
||||
|
||||
|
||||
class Extractor:
|
||||
_llm: CompletionLLM
|
||||
|
||||
def __init__(
    self,
    llm_invoker: CompletionLLM,
    language: str | None = "English",
    entity_types: list[str] | None = None,
):
    """Remember the chat model, the language and the entity types.

    Args:
        llm_invoker: Chat model stored as ``_llm``.
        language: Stored as ``_language``.
        entity_types: Entity type names; ``DEFAULT_ENTITY_TYPES`` is used when
            this is ``None`` or empty.
    """
    chosen_types = entity_types if entity_types else DEFAULT_ENTITY_TYPES
    self._llm, self._language = llm_invoker, language
    self._entity_types = chosen_types
|
||||
|
||||
@timeout(60 * 20)
def _chat(self, system, history, gen_conf=None, task_id=""):
    """Ask the LLM for a completion, with a cache lookup and up to three attempts.

    Args:
        system: System prompt text.
        history: Chat history; deep-copied before being handed to the model.
        gen_conf: Generation settings; ``None`` is treated as an empty dict.
        task_id: When non-empty, ``has_canceled(task_id)`` is polled before
            every attempt.

    Returns:
        The cached response if one exists, otherwise the model's reply with
        everything up to and including the last ``</think>`` removed.

    Raises:
        TaskCanceledException: If the task is cancelled before an attempt.
        Exception: Whatever the third attempt raised, including the error
            raised when the reply contains ``**ERROR**``.
    """
    if gen_conf is None:
        gen_conf = {}
    # Work on copies so the caller's history/config are never mutated by the model call.
    hist = deepcopy(history)
    conf = deepcopy(gen_conf)
    response = get_llm_cache(self._llm.llm_name, system, hist, conf)
    if response:
        return response
    # NOTE(review): presumably fits the system prompt into 92% of the model's context — confirm in message_fit_in.
    _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
    response = ""
    max_attempts = 3
    for attempt in range(max_attempts):
        if task_id and has_canceled(task_id):
            logging.info(f"Task {task_id} cancelled during LLM chat.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")
        try:
            response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
            response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
            if response.find("**ERROR**") >= 0:
                raise Exception(response)
            set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
            break
        except Exception as e:
            logging.exception(e)
            if attempt == max_attempts - 1:
                raise

    return response
|
||||
|
||||
def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
    """Sort parsed records into entities and relationships.

    Each record is split on ``tuple_delimiter``. It is kept as an entity when
    ``handle_single_entity_extraction`` recognises it and its type (compared
    case-insensitively) is one of ``self._entity_types``; otherwise it is
    offered to ``handle_single_relationship_extraction``.

    Returns:
        ``(nodes, edges)``: ``nodes`` maps entity name to a list of entity
        dicts, ``edges`` maps ``(src_id, tgt_id)`` to a list of relation dicts.
    """
    allowed_types = [name.lower() for name in self._entity_types]
    nodes_by_name = defaultdict(list)
    edges_by_pair = defaultdict(list)

    for raw_record in records:
        fields = split_string_by_multi_markers(raw_record, [tuple_delimiter])

        entity = handle_single_entity_extraction(fields, chunk_key)
        is_wanted_entity = entity is not None and entity.get("entity_type", "unknown").lower() in allowed_types
        if is_wanted_entity:
            nodes_by_name[entity["entity_name"]].append(entity)
            continue

        # Not an accepted entity: the record may still describe a relationship.
        relation = handle_single_relationship_extraction(fields, chunk_key)
        if relation is None:
            continue
        edges_by_pair[(relation["src_id"], relation["tgt_id"])].append(relation)

    return dict(nodes_by_name), dict(edges_by_pair)
|
||||
|
||||
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
    """Extract entities and relationships from ``chunks`` and merge duplicates.

    Three phases run in order: per-chunk extraction through
    ``self._process_single_content`` (at most
    ``MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK`` chunks at a time), entity
    merging through ``_merge_nodes``, and relationship merging through
    ``_merge_edges``.

    Args:
        doc_id: Paired with every chunk as ``(doc_id, chunk)`` for extraction.
        chunks: Text chunks to process.
        callback: Optional progress hook, invoked as ``callback(msg=...)``; also
            stored on ``self.callback``.
        task_id: When non-empty, ``has_canceled(task_id)`` is polled between
            phases and inside the workers.

    Returns:
        ``(all_entities_data, all_relationships_data)``: the merged entity dicts
        and the merged relationship dicts.

    Raises:
        TaskCanceledException: If the task is cancelled at any checkpoint or
            inside chunk processing.
        Exception: If more than ``GRAPHRAG_MAX_ERRORS`` (default 3) chunks fail,
            or if a merge task fails.
    """
    self.callback = callback
    loop = asyncio.get_running_loop()
    start_ts = loop.time()

    def raise_if_canceled(stage: str):
        """Abort with ``TaskCanceledException`` if the task has been cancelled."""
        if task_id and has_canceled(task_id):
            raise TaskCanceledException(f"Task {task_id} was cancelled {stage}")

    async def gather_or_cancel(tasks, error_prefix):
        """Await all tasks; on the first failure cancel the rest, let them settle, re-raise."""
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error(f"{error_prefix}{e}")
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    async def extract_all():
        """Run extraction over every chunk and return the collected per-chunk results."""
        out_results = []
        error_count = 0
        max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
        limiter = asyncio.Semaphore(MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK)
        total = len(chunks)

        async def worker(chunk_key_dp: tuple[str, str], idx: int):
            nonlocal error_count
            async with limiter:
                raise_if_canceled("during entity extraction")
                try:
                    await self._process_single_content(chunk_key_dp, idx, total, out_results, task_id)
                except TaskCanceledException:
                    # Cancellation is not a chunk failure; propagate it at once.
                    raise
                except Exception as e:
                    error_count += 1
                    error_msg = f"Error processing chunk {idx + 1}/{total}: {str(e)}"
                    logging.warning(error_msg)
                    if self.callback:
                        self.callback(msg=error_msg)

                    if error_count > max_errors:
                        raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}") from e

        tasks = [asyncio.create_task(worker((doc_id, ck), i)) for i, ck in enumerate(chunks)]
        await gather_or_cancel(tasks, "Error in worker: ")

        if error_count > 0:
            warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
            logging.warning(warning_msg)
            if self.callback:
                self.callback(msg=warning_msg)

        return out_results

    # Phase 1: per-chunk extraction.
    raise_if_canceled("before entity extraction")
    out_results = await extract_all()
    raise_if_canceled("after entity extraction")

    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    sum_token_count = 0
    for m_nodes, m_edges, token_count in out_results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            # Sort the endpoints so (a, b) and (b, a) land in the same bucket.
            maybe_edges[tuple(sorted(k))].extend(v)
        sum_token_count += token_count
    now = loop.time()
    if self.callback:
        self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
    start_ts = now

    # Phase 2: merge the mentions of each entity.
    logging.info("Entities merging...")
    all_entities_data = []
    raise_if_canceled("before nodes merging")
    tasks = [
        asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id))
        for en_nm, ents in maybe_nodes.items()
    ]
    await gather_or_cancel(tasks, "Error merging nodes: ")
    raise_if_canceled("after nodes merging")

    now = loop.time()
    if self.callback:
        self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
    start_ts = now

    # Phase 3: merge the mentions of each relationship.
    logging.info("Relationships merging...")
    all_relationships_data = []
    raise_if_canceled("before relationships merging")
    tasks = [
        asyncio.create_task(self._merge_edges(src, tgt, rels, all_relationships_data, task_id))
        for (src, tgt), rels in maybe_edges.items()
    ]
    await gather_or_cancel(tasks, "Error during relationships merging: ")
    raise_if_canceled("after relationships merging")

    now = loop.time()
    if self.callback:
        self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")

    if not all_entities_data and not all_relationships_data:
        logging.warning("Didn't extract any entities and relationships, maybe your LLM is not working")

    if not all_entities_data:
        logging.warning("Didn't extract any entities")
    if not all_relationships_data:
        logging.warning("Didn't extract any relationships")

    return all_entities_data, all_relationships_data
|
||||
|
||||
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data, task_id=""):
    """Fold every mention of one entity into a single record and append it to the output list.

    The merged record takes the most frequent ``entity_type``, a summary of the
    distinct descriptions (sorted, joined with ``GRAPH_FIELD_SEP`` and passed
    through ``_handle_entity_relation_summary``), and the ``source_id`` values
    gathered by ``flat_uniq_list``.

    Args:
        entity_name: Name shared by all ``entities``.
        entities: Mentions of the entity; nothing happens when this is empty.
        all_relationships_data: List that receives the merged record. Despite
            the name, ``__call__`` passes its entity list here.
        task_id: When non-empty, ``has_canceled(task_id)`` is checked first.

    Raises:
        TaskCanceledException: If the task has been cancelled.
    """
    if task_id and has_canceled(task_id):
        raise TaskCanceledException(f"Task {task_id} was cancelled during merge nodes")
    if not entities:
        return

    # max() keeps the first-seen type on a tie, matching a stable descending sort.
    type_counts = Counter(mention["entity_type"] for mention in entities)
    dominant_type = max(type_counts.items(), key=lambda pair: pair[1])[0]

    distinct_descriptions = sorted({mention["description"] for mention in entities})
    joined_description = GRAPH_FIELD_SEP.join(distinct_descriptions)
    source_ids = flat_uniq_list(entities, "source_id")
    summary = await self._handle_entity_relation_summary(entity_name, joined_description, task_id=task_id)

    all_relationships_data.append(
        {
            "entity_type": dominant_type,
            "description": summary,
            "source_id": source_ids,
            "entity_name": entity_name,
        }
    )
|
||||
|
||||
async def _merge_edges(self, src_id: str, tgt_id: str, edges_data: list[dict], all_relationships_data=None, task_id=""):
    """Fold every mention of one relationship into a single record.

    The merged record sums the ``weight`` values, summarises the distinct
    descriptions (sorted, joined with ``GRAPH_FIELD_SEP`` and passed through
    ``_handle_entity_relation_summary``), and gathers ``keywords`` and
    ``source_id`` with ``flat_uniq_list``.

    Args:
        src_id: First endpoint of the relationship.
        tgt_id: Second endpoint of the relationship.
        edges_data: Mentions of the relationship; nothing happens when empty.
        all_relationships_data: List that receives the merged record. The
            declared default ``None`` is accepted and simply skips the append
            (the original raised ``AttributeError`` in that case).
        task_id: Forwarded to ``_handle_entity_relation_summary``.
    """
    if not edges_data:
        return
    weight = sum(edge["weight"] for edge in edges_data)
    description = GRAPH_FIELD_SEP.join(sorted({edge["description"] for edge in edges_data}))
    description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description, task_id=task_id)
    keywords = flat_uniq_list(edges_data, "keywords")
    source_id = flat_uniq_list(edges_data, "source_id")
    edge_data = dict(src_id=src_id, tgt_id=tgt_id, description=description, keywords=keywords, weight=weight, source_id=source_id)
    if all_relationships_data is not None:
        all_relationships_data.append(edge_data)
|
||||
|
||||
async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange, task_id=""):
    """Merge ``nodes[1:]`` into ``nodes[0]`` inside ``graph`` and record the edits in ``change``.

    The surviving node (``nodes[0]``) keeps its "entity_name", "entity_type"
    and "page_rank"; descriptions are concatenated with ``GRAPH_FIELD_SEP`` and
    ``source_id`` lists are unioned.  Edges of the removed nodes that lead
    outside the merged group are re-attached to the survivor; when the survivor
    already has an edge to the same neighbour, the two edges are merged
    (weights added, descriptions joined and summarised, ``keywords`` /
    ``source_id`` unioned).

    Args:
        graph: Graph that is modified in place.
        nodes: Names of the nodes to merge; nothing happens for fewer than two.
        change: Accumulator whose ``added_updated_nodes``, ``removed_nodes``,
            ``added_updated_edges`` and ``removed_edges`` sets are updated.
        task_id: Optional task id used for cancellation checks.

    Raises:
        TaskCanceledException: If ``task_id`` is set and the task was cancelled.
    """
    if task_id and has_canceled(task_id):
        raise TaskCanceledException(f"Task {task_id} was cancelled during merge graph nodes")

    if len(nodes) <= 1:
        return

    survivor = nodes[0]
    change.added_updated_nodes.add(survivor)
    change.removed_nodes.update(nodes[1:])
    nodes_set = set(nodes)
    node0_attrs = graph.nodes[survivor]
    node0_neighbors = set(graph.neighbors(survivor))

    for node1 in nodes[1:]:
        if task_id and has_canceled(task_id):
            raise TaskCanceledException(f"Task {task_id} was cancelled during merge_graph nodes")

        # Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
        node1_attrs = graph.nodes[node1]
        node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
        node0_attrs["source_id"] = sorted(set(node0_attrs["source_id"] + node1_attrs["source_id"]))

        for neighbor in graph.neighbors(node1):
            change.removed_edges.add(get_from_to(node1, neighbor))
            if neighbor in nodes_set:
                # Edge inside the merged group: it simply disappears.
                continue

            edge1_attrs = graph.get_edge_data(node1, neighbor)
            # Every edge that ends up on the survivor is new or changed, so it is
            # recorded in both branches (previously only merged edges were).
            change.added_updated_edges.add(get_from_to(survivor, neighbor))
            if neighbor in node0_neighbors:
                # Merge two edges
                edge0_attrs = graph.get_edge_data(survivor, neighbor)
                edge0_attrs["weight"] += edge1_attrs["weight"]
                edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
                for attr in ["keywords", "source_id"]:
                    edge0_attrs[attr] = sorted(set(edge0_attrs[attr] + edge1_attrs[attr]))
                edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({survivor}, {neighbor})", edge0_attrs["description"], task_id=task_id)
                graph.add_edge(survivor, neighbor, **edge0_attrs)
            else:
                graph.add_edge(survivor, neighbor, **edge1_attrs)
                # Keep the neighbour snapshot current: a later merged node that
                # shares this neighbour must merge into the edge just created
                # instead of overwriting its attributes.
                node0_neighbors.add(neighbor)

        graph.remove_node(node1)

    node0_attrs["description"] = await self._handle_entity_relation_summary(survivor, node0_attrs["description"], task_id=task_id)
    graph.nodes[survivor].update(node0_attrs)
|
||||
|
||||
async def _handle_entity_relation_summary(self, entity_or_relation_name: str, description: str, task_id="") -> str:
    """Return a bounded description for an entity or relation, summarising via the LLM if needed.

    The description is first passed through ``truncate`` with a budget of 512.
    If the truncated text splits into at most 12 pieces on ``GRAPH_FIELD_SEP``
    it is returned as is; otherwise ``SUMMARIZE_DESCRIPTIONS_PROMPT`` is filled
    in and sent to the chat model, and the model's answer is returned.

    Args:
        entity_or_relation_name: Name inserted into the prompt and the log line.
        description: ``GRAPH_FIELD_SEP``-joined descriptions.
        task_id: Optional task id; checked for cancellation and forwarded to the chat call.

    Raises:
        TaskCanceledException: If ``task_id`` is set and the task was cancelled.
    """

    def _abort_if_cancelled():
        if task_id and has_canceled(task_id):
            raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")

    _abort_if_cancelled()

    token_budget = 512
    clipped = truncate(description, token_budget)
    fragments = clipped.split(GRAPH_FIELD_SEP)
    if len(fragments) <= 12:
        return clipped

    use_prompt = SUMMARIZE_DESCRIPTIONS_PROMPT.format(
        entity_name=entity_or_relation_name,
        description_list=fragments,
        language=self._language,
    )
    logging.info(f"Trigger summary: {entity_or_relation_name}")

    # Check again right before the chat call.
    _abort_if_cancelled()

    async with chat_limiter:
        return await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
152
rag/graphrag/general/graph_extractor.py
Normal file
152
rag/graphrag/general/graph_extractor.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
import tiktoken
|
||||
|
||||
from rag.graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||
from rag.graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from rag.graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter, split_string_by_multi_markers
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
DEFAULT_TUPLE_DELIMITER = "<|>"
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
||||
|
||||
|
||||
@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result: the extracted graph plus its source documents."""

    # Graph produced by the extraction.
    output: nx.Graph
    # Source documents associated with the extraction (key/value types are unconstrained).
    source_docs: dict[Any, Any]
|
||||
|
||||
|
||||
class GraphExtractor(Extractor):
    """Unipartite graph extractor class definition.

    For each text chunk the extractor fills ``GRAPH_EXTRACTION_PROMPT``, asks
    the chat model for entity / relationship records, optionally runs extra
    "gleaning" rounds to pick up missed records, and parses the delimited
    output into candidate nodes and edges.
    """

    _join_descriptions: bool
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _entity_types_key: str
    _input_text_key: str
    _completion_delimiter_key: str
    _entity_name_key: str
    _input_descriptions_key: str
    _extraction_prompt: str
    _summarization_prompt: str
    _loop_args: dict[str, Any]
    _max_gleanings: int
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        language: str | None = "English",
        entity_types: list[str] | None = None,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        input_text_key: str | None = None,
        entity_types_key: str | None = None,
        completion_delimiter_key: str | None = None,
        join_descriptions=True,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition.

        Args:
            llm_invoker: Chat model used for extraction.
            language: Language forwarded to the base ``Extractor``.
            entity_types: Entity types listed in the prompt; may be ``None``.
            tuple_delimiter_key: Prompt variable name for the tuple delimiter.
            record_delimiter_key: Prompt variable name for the record delimiter.
            input_text_key: Prompt variable name for the chunk text.
            entity_types_key: Prompt variable name for the entity types.
            completion_delimiter_key: Prompt variable name for the completion marker.
            join_descriptions: Stored on the instance as ``_join_descriptions``.
            max_gleanings: Number of follow-up rounds; defaults to
                ``ENTITY_EXTRACTION_MAX_GLEANINGS`` when ``None``.
            on_error: Error handler; defaults to a no-op.
        """
        super().__init__(llm_invoker, language, entity_types)
        # TODO: streamline construction
        self._llm = llm_invoker
        self._join_descriptions = join_descriptions
        self._input_text_key = input_text_key or "input_text"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = completion_delimiter_key or "completion_delimiter"
        self._entity_types_key = entity_types_key or "entity_types"
        self._extraction_prompt = GRAPH_EXTRACTION_PROMPT
        self._max_gleanings = max_gleanings if max_gleanings is not None else ENTITY_EXTRACTION_MAX_GLEANINGS
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)

        # Construct the looping arguments
        encoding = tiktoken.get_encoding("cl100k_base")
        yes = encoding.encode("YES")
        no = encoding.encode("NO")
        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

        # ``entity_types`` is declared optional; joining ``None`` would raise
        # TypeError.  Fall back to whatever the base class stored, if anything
        # (NOTE(review): assumes the base keeps it as ``_entity_types`` - confirm).
        if entity_types is None:
            entity_types = getattr(self, "_entity_types", None) or []

        # Wire defaults into the prompt variables
        self._prompt_variables = {
            self._tuple_delimiter_key: DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: DEFAULT_COMPLETION_DELIMITER,
            self._entity_types_key: ",".join(entity_types),
        }

    async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results, task_id=""):
        """Extract candidate nodes and edges from one chunk and append them to ``out_results``.

        Args:
            chunk_key_dp: ``(chunk_key, content)`` pair.
            chunk_seq: Sequence number of the chunk, used only in the progress message.
            num_chunks: Total number of chunks, used for the progress ratio.
            out_results: Output list; ``(maybe_nodes, maybe_edges, token_count)`` is appended.
            task_id: Optional task id forwarded to every chat call.
        """
        token_count = 0
        chunk_key = chunk_key_dp[0]
        content = chunk_key_dp[1]
        variables = {
            **self._prompt_variables,
            self._input_text_key: content,
        }
        hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
        async with chat_limiter:
            response = await thread_pool_exec(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
        # Normalise before use: an empty/None answer must not break token counting.
        response = response or ""
        token_count += num_tokens_from_string(hint_prompt + response)

        results = response
        # NOTE(review): the model's first answer is replayed with role "user";
        # kept as is - confirm this is intended.
        history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]

        # Repeat to ensure we maximize entity count
        for i in range(self._max_gleanings):
            history.append({"role": "user", "content": CONTINUE_PROMPT})
            async with chat_limiter:
                response = await thread_pool_exec(self._chat, "", history, {}, task_id)
            response = response or ""
            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
            results += response

            # if this is the final glean, don't bother updating the continuation flag
            if i >= self._max_gleanings - 1:
                break
            history.append({"role": "assistant", "content": response})
            history.append({"role": "user", "content": LOOP_PROMPT})
            async with chat_limiter:
                continuation = await thread_pool_exec(self._chat, "", history, {}, task_id)
            continuation = continuation or ""
            # Account for the continuation answer itself (not the previous response).
            token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + continuation)
            # Tolerate whitespace, case, quoting and a spelled-out "Yes".
            if continuation.strip().strip(".\"'").upper() not in ("Y", "YES"):
                break
            history.append({"role": "assistant", "content": "Y"})

        records = split_string_by_multi_markers(
            results,
            [self._prompt_variables[self._record_delimiter_key], self._prompt_variables[self._completion_delimiter_key]],
        )
        rcds = []
        for record in records:
            # Keep only the text between the outermost parentheses of each record.
            matched = re.search(r"\((.*)\)", record)
            if matched is None:
                continue
            rcds.append(matched.group(1))
        records = rcds
        maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._prompt_variables[self._tuple_delimiter_key])
        out_results.append((maybe_nodes, maybe_edges, token_count))
        if self.callback:
            self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
|
||||
124
rag/graphrag/general/graph_prompt.py
Normal file
124
rag/graphrag/general/graph_prompt.py
Normal file
@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [GraphRAG](https://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/extract_graph.py)
|
||||
"""
|
||||
|
||||
GRAPH_EXTRACTION_PROMPT = """
|
||||
-Goal-
|
||||
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
||||
|
||||
-Steps-
|
||||
1. Identify all entities. For each identified entity, extract the following information:
|
||||
- entity_name: Name of the entity, capitalized, in language of 'Text'
|
||||
- entity_type: One of the following types: [{entity_types}]
|
||||
- entity_description: Comprehensive description of the entity's attributes and activities in language of 'Text'
|
||||
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
|
||||
|
||||
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
|
||||
For each pair of related entities, extract the following information:
|
||||
- source_entity: name of the source entity, as identified in step 1
|
||||
- target_entity: name of the target entity, as identified in step 1
|
||||
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other in language of 'Text'
|
||||
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
|
||||
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
|
||||
|
||||
3. Return output as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
||||
|
||||
4. When finished, output {completion_delimiter}
|
||||
|
||||
######################
|
||||
-Examples-
|
||||
######################
|
||||
Example 1:
|
||||
|
||||
Entity_types: [person, technology, mission, organization, location]
|
||||
Text:
|
||||
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
|
||||
|
||||
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
|
||||
|
||||
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
|
||||
|
||||
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
|
||||
################
|
||||
Output:
|
||||
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter}
|
||||
#############################
|
||||
Example 2:
|
||||
|
||||
Entity_types: [person, technology, mission, organization, location]
|
||||
Text:
|
||||
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
|
||||
|
||||
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
|
||||
|
||||
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
|
||||
#############
|
||||
Output:
|
||||
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter}
|
||||
#############################
|
||||
Example 3:
|
||||
|
||||
Entity_types: [person, role, technology, organization, event, location, concept]
|
||||
Text:
|
||||
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
|
||||
|
||||
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
|
||||
|
||||
Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
|
||||
|
||||
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
|
||||
|
||||
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
|
||||
#############
|
||||
Output:
|
||||
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
|
||||
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter}
|
||||
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter}
|
||||
#############################
|
||||
-Real Data-
|
||||
######################
|
||||
Entity_types: {entity_types}
|
||||
Text: {input_text}
|
||||
######################
|
||||
Output:"""
|
||||
|
||||
# Follow-up user message asking the model to add entities missed in the previous pass.
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
# Follow-up user message asking whether another pass is needed; a single letter Y or N is requested.
LOOP_PROMPT = "It appears some entities may have still been missed. Answer Y if there are still entities that need to be added, or N if there are none. Please answer with a single letter Y or N.\n"
|
||||
|
||||
SUMMARIZE_DESCRIPTIONS_PROMPT = """
|
||||
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
|
||||
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
|
||||
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
|
||||
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
|
||||
Make sure it is written in third person, and include the entity names so we have the full context.
|
||||
Use {language} as output language.
|
||||
|
||||
#######
|
||||
-Data-
|
||||
Entities: {entity_name}
|
||||
Description List: {description_list}
|
||||
#######
|
||||
"""
|
||||
610
rag/graphrag/general/index.py
Normal file
610
rag/graphrag/general/index.py
Normal file
@ -0,0 +1,610 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.misc_utils import get_uuid
|
||||
from common.connection_utils import timeout
|
||||
from rag.graphrag.entity_resolution import EntityResolution
|
||||
from rag.graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
||||
from rag.graphrag.general.extractor import Extractor
|
||||
from rag.graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
||||
from rag.graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
||||
from rag.graphrag.utils import (
|
||||
GraphChange,
|
||||
chunk_id,
|
||||
does_graph_contains,
|
||||
get_graph,
|
||||
graph_merge,
|
||||
set_graph,
|
||||
tidy_graph,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
from common import settings
|
||||
|
||||
|
||||
async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
    """Build the knowledge graph for one document and merge it into its knowledge base.

    Steps: load the document's chunks, generate a subgraph (bounded by a
    timeout), then - while holding the per-KB distributed lock - merge the
    subgraph and optionally run entity resolution and community extraction.

    Args:
        row: Task row; ``tenant_id``, ``kb_id``, ``doc_id`` and
            ``kb_parser_config`` are read, plus ``id`` as the task id.
        language: Language forwarded to the extractor.
        with_resolution: Run ``resolve_entities`` after merging.
        with_community: Run ``extract_community`` after merging.
        chat_model: Chat model forwarded to the pipeline stages.
        embedding_model: Embedding model forwarded to the pipeline stages.
        callback: Progress callback, invoked with ``msg=...``.

    Raises:
        asyncio.TimeoutError: If subgraph generation exceeds the timeout.
    """
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = asyncio.get_running_loop().time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
        chunks.append(d["content_with_weight"])

    # Ten minutes per chunk (at least two minutes) when the assertion is enabled.
    timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000

    # Read the graphrag section once and tolerate it being absent; indexing
    # ["graphrag"] directly raised KeyError for configs without that section.
    graphrag_conf = row["kb_parser_config"].get("graphrag") or {}
    kg_extractor = GeneralKGExt if graphrag_conf.get("method") == "general" else LightKGExt

    try:
        subgraph = await asyncio.wait_for(
            generate_subgraph(
                kg_extractor,
                tenant_id,
                kb_id,
                doc_id,
                chunks,
                language,
                graphrag_conf.get("entity_types", []),
                chat_model,
                embedding_model,
                callback,
                # Same as the per-KB pipeline: lets the extraction honour cancellation.
                task_id=row.get("id", ""),
            ),
            timeout=timeout_sec,
        )
    except asyncio.TimeoutError:
        logging.error("generate_subgraph timeout")
        raise

    if not subgraph:
        return

    graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
    await graphrag_task_lock.spin_acquire()

    try:
        # Inside the try so that a failing callback cannot leak the lock.
        callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
        subgraph_nodes = set(subgraph.nodes())
        new_graph = await merge_subgraph(
            tenant_id,
            kb_id,
            doc_id,
            subgraph,
            embedding_model,
            callback,
        )
        assert new_graph is not None

        if not with_resolution and not with_community:
            return

        if with_resolution:
            # NOTE(review): the lock is acquired again before each stage,
            # presumably to renew it - confirm against RedisDistributedLock.
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await resolve_entities(
                new_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )
        if with_community:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await extract_community(
                new_graph,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )
    finally:
        graphrag_task_lock.release()
    now = asyncio.get_running_loop().time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return
|
||||
|
||||
|
||||
async def run_graphrag_for_kb(
|
||||
row: dict,
|
||||
doc_ids: list[str],
|
||||
language: str,
|
||||
kb_parser_config: dict,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
*,
|
||||
with_resolution: bool = True,
|
||||
with_community: bool = True,
|
||||
max_parallel_docs: int = 4,
|
||||
) -> dict:
|
||||
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = asyncio.get_running_loop().time()
|
||||
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||
|
||||
if not doc_ids:
|
||||
logging.info(f"Fetching all docs for {kb_id}")
|
||||
docs, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
doc_ids = [doc["id"] for doc in docs]
|
||||
|
||||
doc_ids = list(dict.fromkeys(doc_ids))
|
||||
if not doc_ids:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
|
||||
return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
def load_doc_chunks(doc_id: str) -> list[str]:
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
# DEBUG: fetch all chunks first
|
||||
raw_chunks = list(settings.retriever.chunk_list(
|
||||
doc_id,
|
||||
tenant_id,
|
||||
[kb_id],
|
||||
max_count=10000, # FIX: raise the limit so that all chunks are processed
|
||||
fields=fields_for_chunks,
|
||||
sort_by_position=True,
|
||||
))
|
||||
|
||||
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
|
||||
|
||||
for d in raw_chunks:
|
||||
content = d["content_with_weight"]
|
||||
if num_tokens_from_string(current_chunk + content) < 4096:
|
||||
current_chunk += content
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = content
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
all_doc_chunks: dict[str, list[str]] = {}
|
||||
total_chunks = 0
|
||||
for doc_id in doc_ids:
|
||||
chunks = load_doc_chunks(doc_id)
|
||||
all_doc_chunks[doc_id] = chunks
|
||||
total_chunks += len(chunks)
|
||||
|
||||
if total_chunks == 0:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_docs)
|
||||
|
||||
subgraphs: dict[str, object] = {}
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
|
||||
async def build_one(doc_id: str):
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
chunks = all_doc_chunks.get(doc_id, [])
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
return
|
||||
|
||||
kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
|
||||
|
||||
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||
|
||||
try:
|
||||
sg = await asyncio.wait_for(
|
||||
generate_subgraph(
|
||||
kg_extractor,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"]
|
||||
),
|
||||
timeout=deadline,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
failed_docs.append((doc_id, "timeout"))
|
||||
callback(msg=f"{msg} FAILED: timeout")
|
||||
return
|
||||
if sg:
|
||||
subgraphs[doc_id] = sg
|
||||
callback(msg=f"{msg} done")
|
||||
else:
|
||||
failed_docs.append((doc_id, "subgraph is empty"))
|
||||
callback(msg=f"{msg} empty")
|
||||
except TaskCanceledException as canceled:
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {canceled}")
|
||||
except Exception as e:
|
||||
failed_docs.append((doc_id, repr(e)))
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before processing documents.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids]
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in asyncio.gather: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||
if not ok_docs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||
now = asyncio.get_running_loop().time()
|
||||
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before merging subgraphs.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
try:
|
||||
union_nodes: set = set()
|
||||
final_graph = None
|
||||
|
||||
for doc_id in ok_docs:
|
||||
sg = subgraphs[doc_id]
|
||||
union_nodes.update(set(sg.nodes()))
|
||||
|
||||
new_graph = await merge_subgraph(
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
sg,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if new_graph is not None:
|
||||
final_graph = new_graph
|
||||
|
||||
if final_graph is None:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
|
||||
else:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
|
||||
finally:
|
||||
kb_lock.release()
|
||||
|
||||
if not with_resolution and not with_community:
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
|
||||
|
||||
try:
|
||||
subgraph_nodes = set()
|
||||
for sg in subgraphs.values():
|
||||
subgraph_nodes.update(set(sg.nodes()))
|
||||
|
||||
if with_resolution:
|
||||
await resolve_entities(
|
||||
final_graph,
|
||||
subgraph_nodes,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
None,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
|
||||
if with_community:
|
||||
await extract_community(
|
||||
final_graph,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
None,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
finally:
|
||||
kb_lock.release()
|
||||
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
|
||||
return {
|
||||
"ok_docs": ok_docs,
|
||||
"failed_docs": failed_docs, # [(doc_id, error), ...]
|
||||
"total_docs": len(doc_ids),
|
||||
"total_chunks": total_chunks,
|
||||
"seconds": now - start,
|
||||
}
|
||||
|
||||
|
||||
async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
    task_id: str = "",
):
    """Extract the knowledge subgraph of one document and store it as a "subgraph" chunk.

    Returns ``None`` without running the extractor when ``does_graph_contains``
    reports that ``doc_id`` is already covered; otherwise returns the newly
    built ``nx.Graph`` after replacing the document's stored subgraph chunk.

    Raises:
        TaskCanceledException: when ``task_id`` is set and ``has_canceled``
            reports cancellation at one of the checkpoints.
        AssertionError: when an extracted entity or relation has no
            "description" key.
    """

    def _abort_if_canceled(stage: str) -> None:
        # Cancellation is only checked when the caller supplied a task id.
        if task_id and has_canceled(task_id):
            callback(msg=f"Task {task_id} cancelled during {stage} for doc {doc_id}.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")

    _abort_if_canceled("subgraph generation")

    if await does_graph_contains(tenant_id, kb_id, doc_id):
        callback(msg=f"Graph already contains {doc_id}")
        return None

    started_at = asyncio.get_running_loop().time()
    extract = extractor(llm_bdl, language=language, entity_types=entity_types)
    entities, relations = await extract(doc_id, chunks, callback, task_id=task_id)

    subgraph = nx.Graph()
    for entity in entities:
        _abort_if_canceled("entity processing")
        assert "description" in entity, f"entity {entity} does not have description"
        entity["source_id"] = [doc_id]
        subgraph.add_node(entity["entity_name"], **entity)

    skipped = 0
    for relation in relations:
        _abort_if_canceled("relationship processing")
        assert "description" in relation, f"relation {relation} does not have description"
        # Edges are only kept when both endpoints were extracted as entities.
        if not (subgraph.has_node(relation["src_id"]) and subgraph.has_node(relation["tgt_id"])):
            skipped += 1
            continue
        relation["source_id"] = [doc_id]
        subgraph.add_edge(relation["src_id"], relation["tgt_id"], **relation)
    if skipped:
        callback(msg=f"ignored {skipped} relations due to missing entities.")
    tidy_graph(subgraph, callback, check_attribute=False)

    subgraph.graph["source_id"] = [doc_id]
    serialized = json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False)
    chunk = {
        "content_with_weight": serialized,
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)

    # Drop the previously stored subgraph of this document before inserting the new one.
    await thread_pool_exec(
        settings.docStoreConn.delete,
        {"knowledge_graph_kwd": "subgraph", "source_id": doc_id},
        search.index_name(tenant_id),
        kb_id,
    )
    await thread_pool_exec(
        settings.docStoreConn.insert,
        [{"id": cid, **chunk}],
        search.index_name(tenant_id),
        kb_id,
    )

    elapsed = asyncio.get_running_loop().time() - started_at
    callback(msg=f"generated subgraph for doc {doc_id} in {elapsed:.2f} seconds.")
    return subgraph
|
||||
|
||||
|
||||
@timeout(60 * 3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    subgraph: nx.Graph,
    embedding_model,
    callback,
):
    """Merge a document subgraph into the knowledge-base graph and persist the result.

    The current graph is loaded with ``get_graph``. When none exists, the
    subgraph itself becomes the new graph and all of its nodes and edges are
    recorded as added/updated. PageRank is recomputed on the result, which is
    then saved through ``set_graph`` and returned.
    """
    started_at = asyncio.get_running_loop().time()
    change = GraphChange()
    existing = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])

    if existing is None:
        merged = subgraph
        change.added_updated_nodes = set(merged.nodes())
        change.added_updated_edges = set(merged.edges())
    else:
        logging.info("Merge with an exiting graph...................")
        tidy_graph(existing, callback)
        merged = graph_merge(existing, subgraph, change)

    # Store each node's PageRank score on the node itself.
    for name, score in nx.pagerank(merged).items():
        merged.nodes[name]["pagerank"] = score

    await set_graph(tenant_id, kb_id, embedding_model, merged, change, callback)
    elapsed = asyncio.get_running_loop().time() - started_at
    callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {elapsed:.2f} seconds.")
    return merged
|
||||
|
||||
|
||||
@timeout(60 * 30, 1)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
    task_id: str = "",
):
    """Run ``EntityResolution`` on ``graph`` and persist the resolved graph via ``set_graph``.

    Raises:
        TaskCanceledException: when ``task_id`` is set and ``has_canceled``
            reports cancellation before or after the resolution step.
    """

    def _abort_if_canceled(when: str) -> None:
        # Cancellation is only checked when the caller supplied a task id.
        if task_id and has_canceled(task_id):
            callback(msg=f"Task {task_id} cancelled {when} entity resolution.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")

    _abort_if_canceled("during")

    started_at = asyncio.get_running_loop().time()
    resolver = EntityResolution(llm_bdl)
    resolution = await resolver(graph, subgraph_nodes, callback=callback, task_id=task_id)
    graph, change = resolution.graph, resolution.change
    callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
    callback(msg="Graph resolution updated pagerank.")

    # Re-check so a cancelled task does not write the resolved graph back.
    _abort_if_canceled("after")

    await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
    elapsed = asyncio.get_running_loop().time() - started_at
    callback(msg=f"Graph resolution done in {elapsed:.2f}s.")
|
||||
|
||||
|
||||
@timeout(60 * 30, 1)
async def extract_community(
    graph,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
    task_id: str = "",
):
    """Generate community reports for ``graph`` and index them as "community_report" chunks.

    Existing community-report chunks of the knowledge base are deleted before
    the new ones are inserted in small batches.

    Returns:
        ``(community_structure, community_reports)`` as produced by
        ``CommunityReportsExtractor``.

    Raises:
        TaskCanceledException: when ``task_id`` is set and ``has_canceled``
            reports cancellation at one of the checkpoints.
        Exception: when the document store reports an insert error.
    """

    def _abort_if_canceled(phase: str) -> None:
        # Cancellation is only checked when the caller supplied a task id.
        if task_id and has_canceled(task_id):
            callback(msg=f"Task {task_id} cancelled {phase}.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")

    loop = asyncio.get_running_loop()
    _abort_if_canceled("before community extraction")

    phase_started = loop.time()
    extractor = CommunityReportsExtractor(llm_bdl)
    cr = await extractor(graph, callback=callback, task_id=task_id)

    _abort_if_canceled("during community extraction")

    community_structure = cr.structured_output
    community_reports = cr.output
    doc_ids = graph.graph["source_id"]

    now = loop.time()
    callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - phase_started:.2f}s.")
    # Restart the timer so the second message reports the indexing phase only.
    phase_started = now
    _abort_if_canceled("during community indexing")

    chunks = []
    for structure, report in zip(community_structure, community_reports):
        evidences = "\n".join(finding.get("explanation", "") for finding in structure["findings"])
        payload = {"report": report, "evidences": evidences}
        chunk = {
            "id": get_uuid(),
            "docnm_kwd": structure["title"],
            "title_tks": rag_tokenizer.tokenize(structure["title"]),
            "content_with_weight": json.dumps(payload, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(payload["report"] + " " + payload["evidences"]),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": structure["weight"],
            "entities_kwd": structure["entities"],
            "important_kwd": structure["entities"],
            "kb_id": kb_id,
            "source_id": list(doc_ids),
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    await thread_pool_exec(
        settings.docStoreConn.delete,
        {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
        search.index_name(tenant_id),
        kb_id,
    )

    es_bulk_size = 4
    for offset in range(0, len(chunks), es_bulk_size):
        insert_error = await thread_pool_exec(
            settings.docStoreConn.insert,
            chunks[offset : offset + es_bulk_size],
            search.index_name(tenant_id),
            kb_id,
        )
        # A truthy result from insert is treated as an error description.
        if insert_error:
            raise Exception(f"Insert chunk error: {insert_error}, please check log file and Elasticsearch/Infinity status!")

    _abort_if_canceled("after community indexing")

    now = loop.time()
    callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - phase_started:.2f}s.")
    return community_structure, community_reports
|
||||
149
rag/graphrag/general/leiden.py
Normal file
149
rag/graphrag/general/leiden.py
Normal file
@ -0,0 +1,149 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import html
|
||||
from typing import Any, cast
|
||||
from graspologic.partition import hierarchical_leiden
|
||||
from graspologic.utils import largest_connected_component
|
||||
import networkx as nx
|
||||
from networkx import is_empty
|
||||
|
||||
|
||||
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Rebuild ``graph`` with nodes and edges inserted in a sorted, repeatable order."""
    directed = graph.is_directed()
    stable = nx.DiGraph() if directed else nx.Graph()

    stable.add_nodes_from(sorted(graph.nodes(data=True), key=lambda item: item[0]))

    edge_list = list(graph.edges(data=True))
    if not directed:
        # In an undirected graph A-B and B-A are the same edge, yet consumers
        # may depend on the order in which endpoints are reported. Putting the
        # smaller endpoint first makes that order independent of how the edge
        # was originally inserted.
        edge_list = [
            (target, source, attrs) if source > target else (source, target, attrs)
            for source, target, attrs in edge_list
        ]

    edge_list.sort(key=lambda edge: f"{edge[0]} -> {edge[1]}")
    stable.add_edges_from(edge_list)
    return stable
|
||||
|
||||
|
||||
def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
    """Relabel every node to its upper-cased, stripped and HTML-unescaped name."""
    renamed = {}
    for node in graph.nodes():
        renamed[node] = html.unescape(node.upper().strip())  # type: ignore
    return nx.relabel_nodes(graph, renamed)
|
||||
|
||||
|
||||
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Return the largest connected component with normalized names and a stable node/edge order."""
    # Work on a copy so the caller's graph is not handed to the component extraction.
    component = cast(nx.Graph, largest_connected_component(graph.copy()))
    return _stabilize_graph(normalize_node_names(component))
|
||||
|
||||
|
||||
def _compute_leiden_communities(
    graph: nx.Graph | nx.DiGraph,
    max_cluster_size: int,
    use_lcc: bool,
    seed=0xDEADBEEF,
) -> dict[int, dict[str, int]]:
    """Return ``{level: {node: cluster}}`` from hierarchical Leiden; empty for an empty graph."""
    communities_by_level: dict[int, dict[str, int]] = {}
    if is_empty(graph):
        return communities_by_level

    target = stable_largest_connected_component(graph) if use_lcc else graph
    partitions = hierarchical_leiden(target, max_cluster_size=max_cluster_size, random_seed=seed)
    for partition in partitions:
        communities_by_level.setdefault(partition.level, {})[partition.node] = partition.cluster

    return communities_by_level
|
||||
|
||||
|
||||
def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
    """Cluster ``graph`` with hierarchical Leiden and group its nodes by level and community.

    Recognised ``args`` keys:
        max_cluster_size: passed to the Leiden step (default 12).
        use_lcc: restrict clustering to the largest connected component (default True).
        verbose: emit a debug log line with the parameters (default False).
        seed: random seed for Leiden (default 0xDEADBEEF).
        levels: levels to report; defaults to every level Leiden produced.

    Returns:
        ``{level: {community_id: {"weight": float, "nodes": [node_id, ...]}}}``.
        A community's weight is the sum of ``rank * weight`` over its nodes
        (defaults 0 and 1), divided by the largest weight on that level when
        that maximum is non-zero. An empty graph yields ``{}``; a requested
        level that Leiden did not produce yields an empty mapping.
    """
    max_cluster_size = args.get("max_cluster_size", 12)
    use_lcc = args.get("use_lcc", True)
    if args.get("verbose", False):
        logging.debug(
            "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
        )

    nodes = set(graph.nodes())
    if not nodes:
        return {}

    node_id_to_community_map = _compute_leiden_communities(
        graph=graph,
        max_cluster_size=max_cluster_size,
        use_lcc=use_lcc,
        seed=args.get("seed", 0xDEADBEEF),
    )

    # If they don't pass in levels, use them all
    levels = args.get("levels")
    if levels is None:
        levels = sorted(node_id_to_community_map.keys())

    results_by_level: dict[int, dict[str, dict]] = {}
    for level in levels:
        result: dict[str, dict] = {}
        results_by_level[level] = result

        # .get(): a caller-requested level that Leiden did not produce is
        # reported as empty instead of raising KeyError.
        for node_id, raw_community_id in node_id_to_community_map.get(level, {}).items():
            if node_id not in nodes:
                logging.warning(f"Node {node_id} not found in the graph.")
                continue
            community = result.setdefault(str(raw_community_id), {"weight": 0, "nodes": []})
            community["nodes"].append(node_id)
            attrs = graph.nodes[node_id]
            community["weight"] += attrs.get("rank", 0) * attrs.get("weight", 1)

        # Normalise weights per level; skip when there is nothing to scale by.
        max_weight = max((comm["weight"] for comm in result.values()), default=0)
        if max_weight == 0:
            continue
        for comm in result.values():
            comm["weight"] /= max_weight

    return results_by_level
|
||||
|
||||
|
||||
def add_community_info2graph(graph: nx.Graph, nodes: list[str], community_title):
    """Record ``community_title`` in the "communities" attribute of each listed node.

    The attribute is created when missing. Titles are de-duplicated while
    keeping first-seen order, so the stored list does not depend on string
    hash ordering between runs.
    """
    for n in nodes:
        attrs = graph.nodes[n]
        existing = attrs.get("communities", [])
        # dict.fromkeys drops duplicates and preserves insertion order.
        attrs["communities"] = list(dict.fromkeys([*existing, community_title]))
|
||||
193
rag/graphrag/general/mind_map_extractor.py
Normal file
193
rag/graphrag/general/mind_map_extractor.py
Normal file
@ -0,0 +1,193 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import collections
|
||||
import re
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from rag.graphrag.general.extractor import Extractor
|
||||
from rag.graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
||||
from rag.graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import markdown_to_json
|
||||
from functools import reduce
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
@dataclass
class MindMapResult:
    """Result container returned by the mind-map extractor."""

    # Mind-map tree; the extractor fills it with nested {"id", "children"} dicts.
    output: dict
|
||||
|
||||
|
||||
class MindMapExtractor(Extractor):
    """Build a mind-map tree from text sections by prompting a chat model for Markdown."""

    _input_text_key: str
    _mind_map_prompt: str
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        input_text_key: str | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Store the chat model and fall back to the default prompt, text key and no-op error handler."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)

    def _key(self, k):
        """Return ``k`` with every run of ``*`` (Markdown emphasis markers) removed."""
        return re.sub(r"\*+", "", k)

    def _be_children(self, obj: dict, keyset: set):
        """Convert a parsed Markdown value into a list of ``{"id", "children"}`` nodes.

        A string is treated as a one-element list. List items become leaf
        nodes (empty ids dropped) and are added to ``keyset`` as given. For a
        dict, each cleaned key not yet in ``keyset`` becomes a node whose
        children are built recursively; keys already seen are skipped.
        """
        if isinstance(obj, str):
            obj = [obj]
        if isinstance(obj, list):
            keyset.update(obj)
            obj = [re.sub(r"\*+", "", i) for i in obj]
            return [{"id": i, "children": []} for i in obj if i]
        arr = []
        for k, v in obj.items():
            k = self._key(k)
            if k and k not in keyset:
                keyset.add(k)
                arr.append(
                    {
                        "id": k,
                        "children": self._be_children(v, keyset)
                    }
                )
        return arr

    async def __call__(
        self, sections: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> MindMapResult:
        """Batch ``sections`` by token budget, extract a mind map per batch and merge them.

        Returns a ``MindMapResult`` whose ``output`` is a ``{"id", "children"}``
        tree. When no batch yields any content, the tree is an empty "root".
        Any exception from a batch cancels the remaining batches and is re-raised.
        """
        if prompt_variables is None:
            prompt_variables = {}

        res = []
        token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
        texts = []
        cnt = 0
        tasks = []
        for section in sections:
            section_cnt = num_tokens_from_string(section)
            # Flush the current batch before it would reach the token budget.
            if cnt + section_cnt >= token_count and texts:
                tasks.append(asyncio.create_task(
                    self._process_document("".join(texts), prompt_variables, res)
                ))
                texts = []
                cnt = 0

            texts.append(section)
            cnt += section_cnt
        if texts:
            tasks.append(asyncio.create_task(
                self._process_document("".join(texts), prompt_variables, res)
            ))
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error(f"Error processing document: {e}")
            for t in tasks:
                t.cancel()
            # Let the cancelled tasks finish before propagating the error.
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

        empty_result = {"id": "root", "children": []}
        if not res:
            return MindMapResult(output=empty_result)

        merge_json = reduce(self._merge, res)
        # Every batch may have parsed to an empty dict (no Markdown headings);
        # there is then no first key to promote to the root.
        if not merge_json:
            return MindMapResult(output=empty_result)

        if len(merge_json) > 1:
            keys = [self._key(k) for k, v in merge_json.items() if isinstance(v, dict)]
            keyset = set(i for i in keys if i)
            merge_json = {
                "id": "root",
                "children": [
                    {
                        "id": self._key(k),
                        "children": self._be_children(v, keyset)
                    }
                    for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
                ]
            }
        else:
            # A single top-level heading becomes the root itself.
            raw_key, value = next(iter(merge_json.items()))
            k = self._key(raw_key)
            merge_json = {"id": k, "children": self._be_children(value, {k})}

        return MindMapResult(output=merge_json)

    def _merge(self, d1, d2):
        """Merge ``d1`` into ``d2`` in place and return ``d2``.

        For shared keys: two dicts are merged recursively, two lists are
        concatenated onto ``d2``'s list, anything else is overwritten by
        ``d1``'s value. Keys only present in ``d1`` are copied over.
        """
        for k in d1:
            if k in d2:
                if isinstance(d1[k], dict) and isinstance(d2[k], dict):
                    self._merge(d1[k], d2[k])
                elif isinstance(d1[k], list) and isinstance(d2[k], list):
                    d2[k].extend(d1[k])
                else:
                    d2[k] = d1[k]
            else:
                d2[k] = d1[k]

        return d2

    def _list_to_kv(self, data):
        """Rewrite list values of ``data`` (recursively through dicts) as dicts, in place.

        For a list value, every element that is itself a list and has a
        predecessor yields ``{predecessor: element[0]}``; all other elements
        are dropped. Returns ``data``.
        """
        for key, value in data.items():
            if isinstance(value, dict):
                self._list_to_kv(value)
            elif isinstance(value, list):
                new_value = {}
                for i in range(len(value)):
                    if isinstance(value[i], list) and i > 0:
                        new_value[value[i - 1]] = value[i][0]
                data[key] = new_value
            else:
                continue
        return data

    def _todict(self, layer: collections.OrderedDict):
        """Convert nested ``OrderedDict`` layers to plain dicts, then apply ``_list_to_kv``."""
        to_ret = layer
        if isinstance(layer, collections.OrderedDict):
            to_ret = dict(layer)

        # NOTE(review): recursing into a non-mapping value (str/list) raises
        # AttributeError inside the nested call, which is caught here and ends
        # this level's loop early; later siblings keep their original type.
        # Statement order is deliberately left untouched.
        try:
            for key, value in to_ret.items():
                to_ret[key] = self._todict(value)
        except AttributeError:
            pass

        return self._list_to_kv(to_ret)

    async def _process_document(
        self, text: str, prompt_variables: dict[str, str], out_res
    ) -> None:
        """Ask the chat model for a Markdown mind map of ``text`` and append the parsed dict to ``out_res``."""
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        async with chat_limiter:
            response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
        # Strip code-fence marker lines so only the Markdown body is parsed.
        response = re.sub(r"```[^\n]*", "", response)
        logging.debug(response)
        # Parse once and reuse the result for both logging and output.
        parsed = self._todict(markdown_to_json.dictify(response))
        logging.debug(parsed)
        out_res.append(parsed)
|
||||
35
rag/graphrag/general/mind_map_prompt.py
Normal file
35
rag/graphrag/general/mind_map_prompt.py
Normal file
@ -0,0 +1,35 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# Prompt template for mind-map extraction; "{input_text}" is the placeholder
# that gets replaced with the text to summarize.
MIND_MAP_EXTRACTION_PROMPT = """
- Role: You're a talented text processor to summarize a piece of text into a mind map.

- Step of task:
  1. Generate a title for user's 'TEXT'.
  2. Classify the 'TEXT' into sections of a mind map.
  3. If the subject matter is really complex, split them into sub-sections and sub-subsections.
  4. Add a short content summary of the bottom level section.

- Output requirement:
  - Generate at least 4 levels.
  - Always try to maximize the number of sub-sections.
  - In language of 'Text'
  - MUST IN FORMAT OF MARKDOWN

-TEXT-
{input_text}

"""
|
||||
110
rag/graphrag/general/smoke.py
Normal file
110
rag/graphrag/general/smoke.py
Normal file
@ -0,0 +1,110 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import networkx as nx
|
||||
|
||||
from common.constants import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from rag.graphrag.general.graph_extractor import GraphExtractor
|
||||
from rag.graphrag.general.index import update_graph, with_resolution, with_community
|
||||
from common import settings
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
|
||||
def callback(prog=None, msg="Processing..."):
    """Progress hook for the smoke run: log ``msg`` at INFO level; ``prog`` is ignored."""
    logging.info(msg)
|
||||
|
||||
|
||||
async def main():
    """Smoke-test graph building, entity resolution and community extraction for one document.

    Reads ``--tenant_id`` and ``--doc_id`` from the command line, loads up to
    six chunks of the document, and prints the resulting graph, community
    structure and community reports.

    Raises:
        LookupError: when the document, tenant or knowledge base is not found.
    """
    parser = argparse.ArgumentParser()
    # Both ids are mandatory, so no default value is declared for them.
    parser.add_argument(
        "-t",
        "--tenant_id",
        help="Tenant ID",
        action="store",
        required=True,
    )
    parser.add_argument(
        "-d",
        "--doc_id",
        help="Document ID",
        action="store",
        required=True,
    )
    args = parser.parse_args()

    found, doc = DocumentService.get_by_id(args.doc_id)
    if not found:
        raise LookupError("Document not found.")
    kb_id = doc.kb_id

    chunks = [
        d["content_with_weight"]
        for d in settings.retriever.chunk_list(
            args.doc_id,
            args.tenant_id,
            [kb_id],
            max_count=6,
            fields=["content_with_weight"],
        )
    ]

    # NOTE(review): assumes these get_by_id calls return (found, record) like
    # DocumentService.get_by_id above — confirm against the service classes.
    found, tenant = TenantService.get_by_id(args.tenant_id)
    if not found:
        raise LookupError("Tenant not found.")
    llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)

    found, kb = KnowledgebaseService.get_by_id(kb_id)
    if not found:
        raise LookupError("Knowledge base not found.")
    embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

    graph, doc_ids = await update_graph(
        GraphExtractor,
        args.tenant_id,
        kb_id,
        args.doc_id,
        chunks,
        "English",
        llm_bdl,
        embed_bdl,
        callback,
    )
    print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))

    await with_resolution(
        args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
    )
    community_structure, community_reports = await with_community(
        args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
    )

    print(
        "------------------ COMMUNITY STRUCTURE--------------------\n",
        json.dumps(community_structure, ensure_ascii=False, indent=2),
    )
    print(
        "------------------ COMMUNITY REPORTS----------------------\n",
        community_reports,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # asyncio.run() expects a coroutine object, so main has to be called here;
    # passing the bare function raises ValueError.
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user