diff --git a/api/db/init_data.py b/api/db/init_data.py index 36c25a793..988a953e0 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -386,7 +386,7 @@ def init_llm_factory(): "fid": factory_infos[7]["name"], "llm_name": "maidalun1020/bce-reranker-base_v1", "tags": "RE-RANK, 8K", - "max_tokens": 8196, + "max_tokens": 512, "model_type": LLMType.RERANK.value }, # ------------------------ DeepSeek ----------------------- diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 5def03519..bc84fa53c 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -113,4 +113,18 @@ class YoudaoRerank(DefaultRerank): YoudaoRerank._model = RerankerModel( model_name_or_path=model_name.replace( "maidalun1020", "InfiniFlow")) + + def similarity(self, query: str, texts: list): + pairs = [(query,truncate(t, self._model.max_length)) for t in texts] + token_count = 0 + for _, t in pairs: + token_count += num_tokens_from_string(t) + batch_size = 32 + res = [] + for i in range(0, len(pairs), batch_size): + scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length) + scores = sigmoid(np.array(scores)).tolist() + res.extend(scores) + return np.array(res), token_count +