mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refactor improve codes for ranker (#8936)
### What problem does this PR solve? Use the normalize method directly ### Type of change - [x] Refactoring
This commit is contained in:
@ -32,11 +32,6 @@ from api.utils.file_utils import get_home_cache_dir
|
|||||||
from api.utils.log_utils import log_exception
|
from api.utils.log_utils import log_exception
|
||||||
from rag.utils import num_tokens_from_string, truncate
|
from rag.utils import num_tokens_from_string, truncate
|
||||||
|
|
||||||
|
|
||||||
def sigmoid(x):
|
|
||||||
return 1 / (1 + np.exp(-x))
|
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
def __init__(self, key, model_name):
|
def __init__(self, key, model_name):
|
||||||
pass
|
pass
|
||||||
@ -133,10 +128,9 @@ class DefaultRerank(Base):
|
|||||||
|
|
||||||
def _compute_batch_scores(self, batch_pairs, max_length=None):
|
def _compute_batch_scores(self, batch_pairs, max_length=None):
|
||||||
if max_length is None:
|
if max_length is None:
|
||||||
scores = self._model.compute_score(batch_pairs)
|
scores = self._model.compute_score(batch_pairs, normalize=True)
|
||||||
else:
|
else:
|
||||||
scores = self._model.compute_score(batch_pairs, max_length=max_length)
|
scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
|
||||||
scores = sigmoid(np.array(scores))
|
|
||||||
if not isinstance(scores, Iterable):
|
if not isinstance(scores, Iterable):
|
||||||
scores = [scores]
|
scores = [scores]
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
Reference in New Issue
Block a user