Introduced beartype (#3460)

### What problem does this PR solve?

Introduced [beartype](https://github.com/beartype/beartype) for runtime
type-checking.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Zhichang Yu
2024-11-18 17:38:17 +08:00
committed by GitHub
parent 3824c1fec0
commit 4413683898
32 changed files with 125 additions and 134 deletions

View File

@ -15,7 +15,6 @@
#
import logging
import re
from typing import Optional
import threading
import requests
from huggingface_hub import snapshot_download
@ -242,10 +241,10 @@ class FastEmbed(Base):
def __init__(
self,
key: Optional[str] = None,
key: str | None = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
cache_dir: str | None = None,
threads: int | None = None,
**kwargs,
):
if not settings.LIGHTEN and not FastEmbed._model:

View File

@ -17,7 +17,6 @@
import logging
import re
import json
from typing import List, Optional, Dict, Union
from dataclasses import dataclass
from rag.utils import rmSpace
@ -37,13 +36,13 @@ class Dealer:
@dataclass
class SearchResult:
total: int
ids: List[str]
query_vector: List[float] = None
field: Optional[Dict] = None
highlight: Optional[Dict] = None
aggregation: Union[List, Dict, None] = None
keywords: Optional[List[str]] = None
group_docs: List[List] = None
ids: list[str]
query_vector: list[float] | None = None
field: dict | None = None
highlight: dict | None = None
aggregation: list | dict | None = None
keywords: list[str] | None = None
group_docs: list[list] | None = None
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = emb_mdl.encode_queries(txt)

View File

@ -17,7 +17,6 @@ import logging
import re
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
from typing import Tuple
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
@ -45,7 +44,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters
def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None):
def __call__(self, chunks: tuple[str, np.ndarray], random_state, callback=None):
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
if len(chunks) <= 1: return

View File

@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from beartype.claw import beartype_packages
beartype_packages(["agent", "api", "deepdoc", "plugins", "rag", "ragflow_sdk"]) # <-- raise exceptions in your code
import logging
import sys
from api.utils.log_utils import initRootLogger

View File

@ -1,19 +1,17 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from dataclasses import dataclass
import numpy as np
import polars as pl
from typing import List, Dict
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = Union[list, np.ndarray]
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
values: Union[list[float], list[int], None] = None
values: list[float] | list[int] | None = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
@ -82,7 +80,7 @@ class MatchSparseExpr(ABC):
sparse_data: SparseVector | dict,
distance_type: str,
topn: int,
opt_params: Optional[dict] = None,
opt_params: dict | None = None,
):
self.vector_column_name = vector_column_name
self.sparse_data = sparse_data
@ -98,7 +96,7 @@ class MatchTensorExpr(ABC):
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
extra_option: dict | None = None,
):
self.column_name = column_name
self.query_data = query_data
@ -108,16 +106,13 @@ class MatchTensorExpr(ABC):
class FusionExpr(ABC):
def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
MatchExpr = Union[
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
]
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr(ABC):
def __init__(self):
@ -229,11 +224,11 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: List[str], fieldnm: str):
def getHighlight(self, res, keywords: list[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod

View File

@ -3,7 +3,6 @@ import re
import json
import time
import os
from typing import List, Dict
import copy
from elasticsearch import Elasticsearch
@ -363,7 +362,7 @@ class ESConnection(DocStoreConnection):
rr.append(d["_source"])
return rr
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@ -382,7 +381,7 @@ class ESConnection(DocStoreConnection):
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
def getHighlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")

View File

@ -3,7 +3,6 @@ import os
import re
import json
import time
from typing import List, Dict
import infinity
from infinity.common import ConflictType, InfinityException
from infinity.index import IndexInfo, IndexType
@ -384,7 +383,7 @@ class InfinityConnection(DocStoreConnection):
def getChunkIds(self, res):
return list(res["id"])
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
    # The return annotation must be ``dict[str, dict]``: it matches the
    # abstract ``DocStoreConnection.getFields`` signature, and this method
    # returns dict objects. ``list[str, dict]`` is not a valid
    # parametrization of ``list`` (which takes a single item type) and would
    # be rejected once annotations are enforced at runtime.
    res_fields = {}
    # No requested fields: return an empty mapping straight away.
    if not fields:
        return {}
@ -412,7 +411,7 @@ class InfinityConnection(DocStoreConnection):
res_fields[id] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
def getHighlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
num_rows = len(res)
column_id = res["id"]