diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 26607a122..783b62968 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -67,12 +67,12 @@ class DefaultRerank(Base): token_count = 0 for _, t in pairs: token_count += num_tokens_from_string(t) - batch_size = 32 + batch_size = 4096 res = [] for i in range(0, len(pairs), batch_size): scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048) - scores = sigmoid(np.array(scores)).tolist() - res.extend(scores) + if isinstance(scores, float): res.append(scores) + else: res.extend(scores) return np.array(res), token_count @@ -124,7 +124,9 @@ class YoudaoRerank(DefaultRerank): for i in range(0, len(pairs), batch_size): scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length) scores = sigmoid(np.array(scores)).tolist() + if isinstance(scores, float): res.append(scores) res.extend(scores) return np.array(res), token_count +