diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 52656e46e..16a415bf7 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -104,10 +104,10 @@ class DefaultEmbedding(Base):
         token_count = 0
         for t in texts:
             token_count += num_tokens_from_string(t)
         ress = []
         for i in range(0, len(texts), batch_size):
-            ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
-        return np.array(ress), token_count
+            # Keep each batch as a numpy array (convert_to_numpy=True) to avoid
+            # the numpy -> Python list -> numpy round-trip of the old code.
+            ress.append(self._model.encode(texts[i : i + batch_size], convert_to_numpy=True))
+        # Single concatenation after the loop: concatenating inside the loop
+        # would re-copy the accumulated array every iteration (O(n^2) copying).
+        # Guard the empty-input case so callers still receive an ndarray.
+        return (np.concatenate(ress, axis=0) if ress else np.array([])), token_count
 
     def encode_queries(self, text: str):
         token_count = num_tokens_from_string(text)