mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add graphrag (#1793)
### What problem does this PR solve? #1594 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
167
graphrag/description_summary.py
Normal file
167
graphrag/description_summary.py
Normal file
@ -0,0 +1,167 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import html
|
||||
import json
|
||||
import logging
|
||||
import numbers
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
# Prompt template used to merge several descriptions of the same entity (or
# entity pair) into one. ``{entity_name}`` and ``{description_list}`` are the
# default placeholder names substituted by SummarizeExtractor.
SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we the have full context.

#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""

# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 128
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummarizationResult:
|
||||
"""Unipartite graph extraction result class definition."""
|
||||
|
||||
items: str | tuple[str, str]
|
||||
description: str
|
||||
|
||||
|
||||
class SummarizeExtractor:
    """Collapse several descriptions of an entity (or entity pair) into one.

    Instances are callable: given the entity name(s) and the list of
    descriptions, they return a ``SummarizationResult``. Zero or one
    description is handled without the LLM; two or more are merged by
    prompting the LLM, in batches when the descriptions do not fit into the
    input-token budget.
    """

    _llm: CompletionLLM
    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
    _on_error: ErrorHandlerFn
    _max_summary_length: int
    _max_input_tokens: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        entity_name_key: str | None = None,
        input_descriptions_key: str | None = None,
        summarization_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_summary_length: int | None = None,
        max_input_tokens: int | None = None,
    ):
        """Store the LLM and configuration, falling back to module defaults.

        Args:
            llm_invoker: Chat model whose ``chat`` method produces summaries.
            entity_name_key: Prompt placeholder for the entity name(s);
                defaults to ``"entity_name"``.
            input_descriptions_key: Prompt placeholder for the descriptions;
                defaults to ``"description_list"``.
            summarization_prompt: Prompt template; defaults to
                ``SUMMARIZE_PROMPT``.
            on_error: Error callback; defaults to a no-op.
            max_summary_length: Defaults to ``DEFAULT_MAX_SUMMARY_LENGTH``.
            max_input_tokens: Token budget for one prompt; defaults to
                ``DEFAULT_MAX_INPUT_TOKENS``.
        """
        # TODO: streamline construction
        self._llm = llm_invoker
        self._entity_name_key = entity_name_key or "entity_name"
        self._input_descriptions_key = input_descriptions_key or "description_list"

        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        # NOTE(review): stored but never read inside this class, so the limit
        # is not currently forwarded to the LLM call - confirm whether it should be.
        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS

    def __call__(
        self,
        items: str | tuple[str, str],
        descriptions: list[str],
    ) -> SummarizationResult:
        """Summarize ``descriptions`` for ``items``.

        Returns:
            A ``SummarizationResult`` whose description is ``""`` for no
            descriptions, the single description unchanged when there is
            exactly one, and the LLM summary otherwise.
        """
        if not descriptions:
            result = ""
        elif len(descriptions) == 1:
            # Nothing to merge; skip the LLM round trip.
            result = descriptions[0]
        else:
            result = self._summarize_descriptions(items, descriptions)

        return SummarizationResult(
            items=items,
            description=result or "",
        )

    def _summarize_descriptions(
        self, items: str | tuple[str, str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions into a single description.

        Descriptions are accumulated until the token budget is exceeded (or
        the input is exhausted), then sent to the LLM. When more descriptions
        remain, the partial summary seeds the next batch so that the final
        result covers every description.
        """
        sorted_items = sorted(items) if isinstance(items, list) else items

        # Safety check, should always be a list
        if not isinstance(descriptions, list):
            descriptions = [descriptions]

        # The prompt template is constant for the whole call, so its token
        # cost is computed once and reused on every budget reset.
        prompt_tokens = num_tokens_from_string(self._summarization_prompt)
        usable_tokens = self._max_input_tokens - prompt_tokens
        descriptions_collected = []
        result = ""
        last_index = len(descriptions) - 1

        for i, description in enumerate(descriptions):
            usable_tokens -= num_tokens_from_string(description)
            descriptions_collected.append(description)

            is_last = i == last_index
            # A single over-long description is still sent on its own, hence
            # the "more than one collected" condition.
            buffer_full = usable_tokens < 0 and len(descriptions_collected) > 1

            if buffer_full or is_last:
                # Calculate result (final or partial). The LLM helper is a
                # plain synchronous call, so it must not be awaited here.
                result = self._summarize_descriptions_with_llm(
                    sorted_items, descriptions_collected
                )

                # If we go for another loop, restart from the partial summary
                if not is_last:
                    descriptions_collected = [result]
                    usable_tokens = (
                        self._max_input_tokens
                        - prompt_tokens
                        - num_tokens_from_string(result)
                    )

        return result

    def _summarize_descriptions_with_llm(
        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
    ):
        """Summarize descriptions using the LLM.

        The entity name(s) and the sorted descriptions are JSON-encoded and
        substituted into the prompt template, which is sent as a single user
        message with an empty system prompt.

        NOTE(review): the value returned by ``chat`` is passed through
        unchanged and is later token-counted and used as a description, so it
        is assumed to be a string - confirm against the chat model API.
        """
        variables = {
            self._entity_name_key: json.dumps(items),
            self._input_descriptions_key: json.dumps(sorted(descriptions)),
        }
        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
        return self._llm.chat("", [{"role": "user", "content": text}])
|
||||
Reference in New Issue
Block a user