### What problem does this PR solve?

### Type of change

- [x] Refactoring
This commit is contained in:
KevinHuSh
2024-04-25 14:14:28 +08:00
committed by GitHub
parent cf9b554c3a
commit 66f8d35632
14 changed files with 124 additions and 34 deletions

View File

@ -229,19 +229,19 @@ class XinferenceEmbed(Base):
return np.array(res.data[0].embedding), res.usage.total_tokens
class QAnythingEmbed(Base):
class YoudaoEmbed(Base):
_client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
from BCEmbedding import EmbeddingModel as qanthing
if not QAnythingEmbed._client:
if not YoudaoEmbed._client:
try:
print("LOADING BCE...")
QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
get_project_base_directory(),
"rag/res/bce-embedding-base_v1"))
except Exception as e:
QAnythingEmbed._client = qanthing(
YoudaoEmbed._client = qanthing(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))
@ -251,10 +251,10 @@ class QAnythingEmbed(Base):
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
res.extend(embds)
return np.array(res), token_count
def encode_queries(self, text):
embds = QAnythingEmbed._client.encode([text])
embds = YoudaoEmbed._client.encode([text])
return np.array(embds[0]), num_tokens_from_string(text)