From c58a1c48ebdce419722e55cefbbeb3f328e815f1 Mon Sep 17 00:00:00 2001 From: Wang Baoling Date: Fri, 31 May 2024 18:03:47 +0800 Subject: [PATCH] Fix: bug #991 (#1013) ### What problem does this PR solve? issue #991 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: KevinHuSh --- api/db/init_data.py | 2 +- rag/llm/rerank_model.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/api/db/init_data.py b/api/db/init_data.py index 36c25a793..988a953e0 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -386,7 +386,7 @@ def init_llm_factory(): "fid": factory_infos[7]["name"], "llm_name": "maidalun1020/bce-reranker-base_v1", "tags": "RE-RANK, 8K", - "max_tokens": 8196, + "max_tokens": 512, "model_type": LLMType.RERANK.value }, # ------------------------ DeepSeek ----------------------- diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 5def03519..bc84fa53c 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -113,4 +113,18 @@ class YoudaoRerank(DefaultRerank): YoudaoRerank._model = RerankerModel( model_name_or_path=model_name.replace( "maidalun1020", "InfiniFlow")) + + def similarity(self, query: str, texts: list): + pairs = [(query,truncate(t, self._model.max_length)) for t in texts] + token_count = 0 + for _, t in pairs: + token_count += num_tokens_from_string(t) + batch_size = 32 + res = [] + for i in range(0, len(pairs), batch_size): + scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length) + scores = sigmoid(np.array(scores)).tolist() + res.extend(scores) + return np.array(res), token_count +