Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Integration with Infinity (#2894)
### What problem does this PR solve?

Integration with Infinity:

- Replaced ELASTICSEARCH with dataStoreConn
- Renamed deleteByQuery to delete
- Renamed bulk to upsertBulk
- Added getHighlight and getAggregation
- Fixed KGSearch.search
- Moved Dealer.sql_retrieval to es_conn.py

### Type of change

- [x] Refactoring
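As a rough illustration of what the rename means for calling code, here is a minimal sketch: Elasticsearch-specific helper calls give way to the backend-agnostic connection whose interface is defined in the file below. Only the module path and the method names `delete` and `insert` come from this commit; `docStoreConn`, `index_name`, `kb_id`, the condition keys, and the wrapper functions themselves are hypothetical.

```python
from rag.utils.doc_store_conn import DocStoreConnection


def remove_doc_chunks(docStoreConn: DocStoreConnection, doc_id: str,
                      index_name: str, kb_id: str) -> int:
    # Previously an Elasticsearch-specific deleteByQuery(); now the
    # backend-agnostic delete() with a conjunctive equality condition.
    return docStoreConn.delete({"doc_id": doc_id}, index_name, kb_id)


def save_chunks(docStoreConn: DocStoreConnection, chunks: list[dict],
                index_name: str, kb_id: str) -> list[str]:
    # Previously an Elasticsearch-specific bulk(); now routed through the
    # abstract insert(), whose declared return type is list[str].
    return docStoreConn.insert(chunks, index_name, kb_id)
```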
rag/utils/doc_store_conn.py (new file, 251 lines)
@@ -0,0 +1,251 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from dataclasses import dataclass
import numpy as np
import polars as pl
from typing import List, Dict

DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = Union[list, np.ndarray]


@dataclass
class SparseVector:
    indices: list[int]
    values: Union[list[float], list[int], None] = None

    def __post_init__(self):
        assert (self.values is None) or (len(self.indices) == len(self.values))

    def to_dict_old(self):
        d = {"indices": self.indices}
        if self.values is not None:
            d["values"] = self.values
        return d

    def to_dict(self):
        if self.values is None:
            raise ValueError("SparseVector.values is None")
        result = {}
        for i, v in zip(self.indices, self.values):
            result[str(i)] = v
        return result

    @staticmethod
    def from_dict(d):
        return SparseVector(d["indices"], d.get("values"))

    def __str__(self):
        return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"

    def __repr__(self):
        return str(self)


class MatchTextExpr(ABC):
    def __init__(
        self,
        fields: str,
        matching_text: str,
        topn: int,
        extra_options: dict = dict(),
    ):
        self.fields = fields
        self.matching_text = matching_text
        self.topn = topn
        self.extra_options = extra_options


class MatchDenseExpr(ABC):
    def __init__(
        self,
        vector_column_name: str,
        embedding_data: VEC,
        embedding_data_type: str,
        distance_type: str,
        topn: int = DEFAULT_MATCH_VECTOR_TOPN,
        extra_options: dict = dict(),
    ):
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        self.extra_options = extra_options


class MatchSparseExpr(ABC):
    def __init__(
        self,
        vector_column_name: str,
        sparse_data: SparseVector | dict,
        distance_type: str,
        topn: int,
        opt_params: Optional[dict] = None,
    ):
        self.vector_column_name = vector_column_name
        self.sparse_data = sparse_data
        self.distance_type = distance_type
        self.topn = topn
        self.opt_params = opt_params


class MatchTensorExpr(ABC):
    def __init__(
        self,
        column_name: str,
        query_data: VEC,
        query_data_type: str,
        topn: int,
        extra_option: Optional[dict] = None,
    ):
        self.column_name = column_name
        self.query_data = query_data
        self.query_data_type = query_data_type
        self.topn = topn
        self.extra_option = extra_option


class FusionExpr(ABC):
    def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
        self.method = method
        self.topn = topn
        self.fusion_params = fusion_params


MatchExpr = Union[
    MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
]


class OrderByExpr(ABC):
    def __init__(self):
        self.fields = list()

    def asc(self, field: str):
        self.fields.append((field, 0))
        return self

    def desc(self, field: str):
        self.fields.append((field, 1))
        return self

    def fields(self):
        return self.fields


class DocStoreConnection(ABC):
    """
    Database operations
    """

    @abstractmethod
    def dbType(self) -> str:
        """
        Return the type of the database.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        raise NotImplementedError("Not implemented")

    """
    Table operations
    """

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """
        Create an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """
        Delete an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """
        Check if an index with given name exists
        """
        raise NotImplementedError("Not implemented")

    """
    CRUD operations
    """

    @abstractmethod
    def search(
        self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str], knowledgebaseIds: list[str]
    ) -> list[dict] | pl.DataFrame:
        """
        Search with given conjunctive equivalent filtering condition and return all fields of matched documents
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """
        Get single chunk with given id
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
        """
        Update or insert a bulk of rows
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """
        Update rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """
        Delete rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    """
    Helper functions for search result
    """

    @abstractmethod
    def getTotal(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: List[str], fieldnm: str):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        raise NotImplementedError("Not implemented")

    """
    SQL
    """

    @abstractmethod
    def sql(sql: str, fetch_size: int, format: str):
        """
        Run the SQL generated by text-to-sql
        """
        raise NotImplementedError("Not implemented")
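To make the shape of the new interface concrete, here is a hedged usage sketch that composes the expression classes above and drives `search()` on some concrete `DocStoreConnection` (for example the Elasticsearch or Infinity connector this PR targets). Field names, weights, the fusion method string, the index name and the knowledge-base id are all invented for illustration; only the class names, method names and parameter order come from the file above.

```python
import numpy as np

from rag.utils.doc_store_conn import (
    DocStoreConnection,
    FusionExpr,
    MatchDenseExpr,
    MatchTextExpr,
    OrderByExpr,
    SparseVector,
)


def hybrid_search(conn: DocStoreConnection, question: str, q_vec: np.ndarray,
                  index_name: str, kb_id: str):
    """Illustrative only: every literal below is a placeholder."""
    # Full-text match; `fields` is a single string per MatchTextExpr.
    text_expr = MatchTextExpr(
        fields="title_tks^10,content_ltks",
        matching_text=question,
        topn=64,
        extra_options={"minimum_should_match": 0.3},
    )
    # Dense-vector match against a hypothetical vector column.
    dense_expr = MatchDenseExpr(
        vector_column_name="q_1024_vec",
        embedding_data=q_vec,
        embedding_data_type="float",
        distance_type="cosine",
        topn=64,
    )
    # Fuse the text and vector result lists; the method string and params
    # are whatever the concrete backend understands.
    fusion = FusionExpr(method="weighted_sum", topn=64,
                        fusion_params={"weights": "0.3,0.7"})

    order_by = OrderByExpr().desc("create_timestamp_flt")
    condition = {"kb_id": kb_id, "available_int": 1}  # conjunctive equality filter

    res = conn.search(
        ["id", "docnm_kwd", "content_with_weight"],  # selectFields
        ["content_ltks"],                            # highlight
        condition,
        [text_expr, dense_expr, fusion],             # matchExprs
        order_by,
        0,           # offset
        10,          # limit
        index_name,  # indexNames (str or list[str])
        [kb_id],     # knowledgebaseIds
    )
    return conn.getChunkIds(res), conn.getFields(res, ["docnm_kwd", "content_with_weight"])


# SparseVector converts between an (indices, values) pair and the
# {"index": value} dict form produced by to_dict():
sv = SparseVector([2, 7, 31], [0.5, 0.25, 0.25])
assert sv.to_dict() == {"2": 0.5, "7": 0.25, "31": 0.25}
```

Passing the `search()` arguments positionally keeps the sketch aligned with the abstract signature rather than assuming any particular connector's keyword names.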