mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Introduced beartype (#3460)
### What problem does this PR solve? Introduced [beartype](https://github.com/beartype/beartype) for runtime type-checking. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -9,8 +9,8 @@ import logging
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from typing import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Callable
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
from graphrag import leiden
|
||||
@ -26,8 +26,8 @@ from timeit import default_timer as timer
|
||||
class CommunityReportsResult:
|
||||
"""Community reports result class definition."""
|
||||
|
||||
output: List[str]
|
||||
structured_output: List[dict]
|
||||
output: list[str]
|
||||
structured_output: list[dict]
|
||||
|
||||
|
||||
class CommunityReportsExtractor:
|
||||
@ -53,7 +53,7 @@ class CommunityReportsExtractor:
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
|
||||
communities: dict[str, dict[str, list]] = leiden.run(graph, {})
|
||||
total = sum([len(comm.items()) for _, comm in communities.items()])
|
||||
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
|
||||
res_str = []
|
||||
|
||||
@ -6,7 +6,6 @@ Reference:
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from graphrag.leiden import stable_largest_connected_component
|
||||
|
||||
@ -9,8 +9,8 @@ import logging
|
||||
import numbers
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Mapping, Callable
|
||||
import tiktoken
|
||||
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
|
||||
|
||||
@ -18,7 +18,6 @@ import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
import networkx as nx
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
@ -53,7 +52,7 @@ def graph_merge(g1, g2):
|
||||
return g
|
||||
|
||||
|
||||
def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
|
||||
def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
|
||||
_, tenant = TenantService.get_by_id(tenant_id)
|
||||
llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||
ext = GraphExtractor(llm_bdl)
|
||||
|
||||
@ -6,8 +6,8 @@ Reference:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, cast, List
|
||||
import html
|
||||
from typing import Any
|
||||
from graspologic.partition import hierarchical_leiden
|
||||
from graspologic.utils import largest_connected_component
|
||||
|
||||
@ -132,7 +132,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
|
||||
return results_by_level
|
||||
|
||||
|
||||
def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
|
||||
def add_community_info2graph(graph: nx.Graph, nodes: list[str], community_title):
|
||||
for n in nodes:
|
||||
if "communities" not in graph.nodes[n]:
|
||||
graph.nodes[n]["communities"] = []
|
||||
|
||||
@ -19,9 +19,9 @@ import collections
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
#
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict
|
||||
|
||||
import pandas as pd
|
||||
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
|
||||
@ -25,7 +24,7 @@ from rag.nlp.search import Dealer
|
||||
|
||||
class KGSearch(Dealer):
|
||||
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
|
||||
def merge_into_first(sres, title="") -> Dict[str, str]:
|
||||
def merge_into_first(sres, title="") -> dict[str, str]:
|
||||
if not sres:
|
||||
return {}
|
||||
content_with_weight = ""
|
||||
|
||||
@ -7,8 +7,7 @@ Reference:
|
||||
|
||||
import html
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user