mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-02 08:35:08 +08:00
refactor: introduce common normalize method in rerank base class (#12550)
### What problem does this PR solve? introduce common normalize method in rerank base class ### Type of change - [x] Refactoring
This commit is contained in:
@ -36,6 +36,22 @@ class Base(ABC):
|
|||||||
def similarity(self, query: str, texts: list):
|
def similarity(self, query: str, texts: list):
|
||||||
raise NotImplementedError("Please implement encode method!")
|
raise NotImplementedError("Please implement encode method!")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_rank(rank: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize rank values to the range 0 to 1.
|
||||||
|
Avoids division by zero if all ranks are identical.
|
||||||
|
"""
|
||||||
|
min_rank = np.min(rank)
|
||||||
|
max_rank = np.max(rank)
|
||||||
|
|
||||||
|
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
||||||
|
rank = (rank - min_rank) / (max_rank - min_rank)
|
||||||
|
else:
|
||||||
|
rank = np.zeros_like(rank)
|
||||||
|
|
||||||
|
return rank
|
||||||
|
|
||||||
|
|
||||||
class JinaRerank(Base):
|
class JinaRerank(Base):
|
||||||
_FACTORY_NAME = "Jina"
|
_FACTORY_NAME = "Jina"
|
||||||
@ -121,15 +137,7 @@ class LocalAIRerank(Base):
|
|||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
log_exception(_e, res)
|
log_exception(_e, res)
|
||||||
|
|
||||||
# Normalize the rank values to the range 0 to 1
|
rank = Base._normalize_rank(rank)
|
||||||
min_rank = np.min(rank)
|
|
||||||
max_rank = np.max(rank)
|
|
||||||
|
|
||||||
# Avoid division by zero if all ranks are identical
|
|
||||||
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
|
||||||
rank = (rank - min_rank) / (max_rank - min_rank)
|
|
||||||
else:
|
|
||||||
rank = np.zeros_like(rank)
|
|
||||||
|
|
||||||
return rank, token_count
|
return rank, token_count
|
||||||
|
|
||||||
@ -215,15 +223,7 @@ class OpenAI_APIRerank(Base):
|
|||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
log_exception(_e, res)
|
log_exception(_e, res)
|
||||||
|
|
||||||
# Normalize the rank values to the range 0 to 1
|
rank = Base._normalize_rank(rank)
|
||||||
min_rank = np.min(rank)
|
|
||||||
max_rank = np.max(rank)
|
|
||||||
|
|
||||||
# Avoid division by zero if all ranks are identical
|
|
||||||
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
|
||||||
rank = (rank - min_rank) / (max_rank - min_rank)
|
|
||||||
else:
|
|
||||||
rank = np.zeros_like(rank)
|
|
||||||
|
|
||||||
return rank, token_count
|
return rank, token_count
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user