add support for Replicate (#1980)

### What problem does this PR solve?

#1853: Add support for Replicate.

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-08-19 10:36:57 +08:00
committed by GitHub
parent be5a67895e
commit 79426fc41f
10 changed files with 94 additions and 12 deletions

View File

@ -561,7 +561,7 @@ class TogetherAIEmbed(OllamaEmbed):
base_url = "https://api.together.xyz/v1"
super().__init__(key, model_name, base_url)
class PerfXCloudEmbed(OpenAIEmbed):
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
if not base_url:
@ -580,4 +580,22 @@ class SILICONFLOWEmbed(OpenAIEmbed):
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
super().__init__(key, model_name, base_url)
super().__init__(key, model_name, base_url)
class ReplicateEmbed(Base):
    """Embedding backend that runs a model through the Replicate client.

    Both ``encode`` and ``encode_queries`` send their input to the model
    named by ``model_name`` via ``Client.run`` and wrap the result in a
    NumPy array.
    """

    def __init__(self, key, model_name, base_url=None):
        """Create the Replicate client.

        Args:
            key: Replicate API token, passed to ``Client(api_token=...)``.
            model_name: Model identifier handed to ``Client.run``.
            base_url: Accepted but not used by this class.
        """
        # Local import: `replicate` is only required once this class is
        # actually instantiated.
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def _run_model(self, texts):
        """Run the model on a list of strings and return the raw output as an array."""
        from json import dumps

        # The payload is sent as a JSON-encoded string under the "texts" key.
        # NOTE(review): assumes the target model accepts this input schema
        # and returns one embedding per input text - confirm per model.
        res = self.client.run(self.model_name, input={"texts": dumps(texts)})
        return np.array(res)

    def encode(self, texts: list, batch_size=32):
        """Embed a list of texts.

        Args:
            texts: Strings to embed; the whole list is sent in one request.
            batch_size: Accepted for interface compatibility; currently
                not used to split the request.

        Returns:
            A tuple ``(embeddings, token_count)`` where ``embeddings`` is the
            model output as a NumPy array and ``token_count`` is the sum of
            ``num_tokens_from_string`` over ``texts``.
        """
        token_count = sum(num_tokens_from_string(text) for text in texts)
        return self._run_model(texts), token_count

    def encode_queries(self, text):
        """Embed a single query string.

        The query goes through the same ``Client.run`` call and input format
        as ``encode`` (previously this called ``Client.embed``, which the
        Replicate client does not provide).

        Returns:
            A tuple ``(embeddings, token_count)``. ``embeddings`` is the model
            output for the one-element list ``[text]`` as a NumPy array
            (presumably shape ``(1, dim)`` - verify against the model).
        """
        return self._run_model([text]), num_tokens_from_string(text)