add support for Baidu yiyan (#2049)

### What problem does this PR solve?

add support for Baidu yiyan

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-08-22 16:45:15 +08:00
committed by GitHub
parent 21f2c5838b
commit 733219cc3f
17 changed files with 307 additions and 13 deletions

View File

@ -32,6 +32,7 @@ import asyncio
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai
import json
class Base(ABC):
def __init__(self, key, model_name):
@ -591,11 +592,34 @@ class ReplicateEmbed(Base):
self.client = Client(api_token=key)
def encode(self, texts: list, batch_size=32):
from json import dumps
res = self.client.run(self.model_name, input={"texts": dumps(texts)})
res = self.client.run(self.model_name, input={"texts": json.dumps(texts)})
return np.array(res), sum([num_tokens_from_string(text) for text in texts])
def encode_queries(self, text):
res = self.client.embed(self.model_name, input={"texts": [text]})
return np.array(res), num_tokens_from_string(text)
class BaiduYiyanEmbed(Base):
def __init__(self, key, model_name, base_url=None):
import qianfan
key = json.loads(key)
ak = key.get("yiyan_ak", "")
sk = key.get("yiyan_sk", "")
self.client = qianfan.Embedding(ak=ak, sk=sk)
self.model_name = model_name
def encode(self, texts: list, batch_size=32):
res = self.client.do(model=self.model_name, texts=texts).body
return (
np.array([r["embedding"] for r in res["data"]]),
res["usage"]["total_tokens"],
)
def encode_queries(self, text):
res = self.client.do(model=self.model_name, texts=[text]).body
return (
np.array([r["embedding"] for r in res["data"]]),
res["usage"]["total_tokens"],
)