Introduced beartype (#3460)

### What problem does this PR solve?

Introduced [beartype](https://github.com/beartype/beartype) for runtime
type-checking.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Zhichang Yu
2024-11-18 17:38:17 +08:00
committed by GitHub
parent 3824c1fec0
commit 4413683898
32 changed files with 125 additions and 134 deletions

View File

@ -9,8 +9,8 @@ import logging
import json
import re
import traceback
from typing import Callable
from dataclasses import dataclass
from typing import List, Callable
import networkx as nx
import pandas as pd
from graphrag import leiden
@ -26,8 +26,8 @@ from timeit import default_timer as timer
class CommunityReportsResult:
"""Community reports result class definition."""
output: List[str]
structured_output: List[dict]
output: list[str]
structured_output: list[dict]
class CommunityReportsExtractor:
@ -53,7 +53,7 @@ class CommunityReportsExtractor:
self._max_report_length = max_report_length or 1500
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
communities: dict[str, dict[str, list]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
res_str = []

View File

@ -6,7 +6,6 @@ Reference:
"""
from typing import Any
import numpy as np
import networkx as nx
from graphrag.leiden import stable_largest_connected_component

View File

@ -9,8 +9,8 @@ import logging
import numbers
import re
import traceback
from typing import Any, Callable
from dataclasses import dataclass
from typing import Any, Mapping, Callable
import tiktoken
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str

View File

@ -18,7 +18,6 @@ import os
from concurrent.futures import ThreadPoolExecutor
import json
from functools import reduce
from typing import List
import networkx as nx
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
@ -53,7 +52,7 @@ def graph_merge(g1, g2):
return g
def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
_, tenant = TenantService.get_by_id(tenant_id)
llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
ext = GraphExtractor(llm_bdl)

View File

@ -6,8 +6,8 @@ Reference:
"""
import logging
from typing import Any, cast, List
import html
from typing import Any
from graspologic.partition import hierarchical_leiden
from graspologic.utils import largest_connected_component
@ -132,7 +132,7 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
return results_by_level
def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
def add_community_info2graph(graph: nx.Graph, nodes: list[str], community_title):
for n in nodes:
if "communities" not in graph.nodes[n]:
graph.nodes[n]["communities"] = []

View File

@ -19,9 +19,9 @@ import collections
import os
import re
import traceback
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements

View File

@ -15,7 +15,6 @@
#
import json
from copy import deepcopy
from typing import Dict
import pandas as pd
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
@ -25,7 +24,7 @@ from rag.nlp.search import Dealer
class KGSearch(Dealer):
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
def merge_into_first(sres, title="") -> Dict[str, str]:
def merge_into_first(sres, title="") -> dict[str, str]:
if not sres:
return {}
content_with_weight = ""

View File

@ -7,8 +7,7 @@ Reference:
import html
import re
from collections.abc import Callable
from typing import Any
from typing import Any, Callable
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]