mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix: disallowed special token while embedding (#8692)
### What problem does this PR solve?

https://github.com/infiniflow/ragflow/issues/8567

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -273,6 +273,8 @@ class ZhipuEmbed(Base):
|
||||
class OllamaEmbed(Base):
|
||||
_FACTORY_NAME = "Ollama"
|
||||
|
||||
_special_tokens = ["<|endoftext|>"]
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
    """Build the Ollama client used for embedding requests.

    Args:
        key: API key for the Ollama endpoint. A falsy value or the
            placeholder ``"x"`` means "no authentication": the client is
            created without an ``Authorization`` header.
        model_name: Name of the embedding model, stored on the instance.
        **kwargs: Must contain ``base_url``, the host passed to ``Client``.

    Raises:
        KeyError: If ``base_url`` is missing from ``kwargs``.
    """
    host = kwargs["base_url"]
    if not key or key == "x":
        self.client = Client(host=host)
    else:
        # The HTTP auth scheme is "Bearer" (RFC 6750); the previous "Bear"
        # prefix produced a header that token-checking proxies reject.
        self.client = Client(host=host, headers={"Authorization": f"Bearer {key}"})
    self.model_name = model_name
|
||||
@ -281,6 +283,9 @@ class OllamaEmbed(Base):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
# remove special tokens if they exist
|
||||
for token in OllamaEmbed._special_tokens:
|
||||
txt = txt.replace(token, "")
|
||||
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
|
||||
try:
|
||||
arr.append(res["embedding"])
|
||||
@ -290,6 +295,9 @@ class OllamaEmbed(Base):
|
||||
return np.array(arr), tks_num
|
||||
|
||||
def encode_queries(self, text):
|
||||
# remove special tokens if they exist
|
||||
for token in OllamaEmbed._special_tokens:
|
||||
text = text.replace(token, "")
|
||||
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
|
||||
try:
|
||||
return np.array(res["embedding"]), 128
|
||||
|
||||
Reference in New Issue
Block a user