# add support for Replicate (#1980)
### What problem does this PR solve?

#1853 add support for Replicate

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
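The integration is built on the official `replicate` Python client. As a point of reference, here is a minimal sketch of the underlying call the new wrappers make; the API token and model slug are placeholders for illustration, not values from this PR:

```python
# Minimal sketch of the raw client call the new wrappers build on.
from replicate.client import Client

client = Client(api_token="r8_xxx")  # placeholder token

# For hosted language models, client.run typically returns an iterator
# of output text chunks, which can be joined or consumed incrementally.
output = client.run(
    "meta/meta-llama-3-8b-instruct",  # illustrative model slug
    input={"system_prompt": "Be concise.", "prompt": "user:What is RAG?"},
)
print("".join(output))
```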
The embedding-model registry gains a Replicate entry:

```diff
@@ -42,7 +42,8 @@ EmbeddingModel = {
     "TogetherAI": TogetherAIEmbed,
     "PerfXCloud": PerfXCloudEmbed,
     "Upstage": UpstageEmbed,
-    "SILICONFLOW": SILICONFLOWEmbed
+    "SILICONFLOW": SILICONFLOWEmbed,
+    "Replicate": ReplicateEmbed
 }
```
Likewise for the chat-model registry:

```diff
@@ -96,7 +97,8 @@ ChatModel = {
     "Upstage":UpstageChat,
     "novita.ai": NovitaAIChat,
     "SILICONFLOW": SILICONFLOWChat,
-    "01.AI": YiChat
+    "01.AI": YiChat,
+    "Replicate": ReplicateChat
 }
```
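With those two entries in place, Replicate becomes selectable by factory name wherever ragflow resolves a model class from these registries. A minimal sketch of that lookup, assuming the dicts are importable from `rag.llm` as in the repo layout (`build_chat_model` is an illustrative helper, not ragflow's API):

```python
# Illustrative registry lookup; build_chat_model is a hypothetical helper.
from rag.llm import ChatModel, EmbeddingModel

def build_chat_model(factory: str, api_key: str, model_name: str):
    if factory not in ChatModel:
        raise LookupError(f"Unsupported chat factory: {factory}")
    return ChatModel[factory](api_key, model_name)

# "Replicate" now resolves to ReplicateChat; key and slug are placeholders.
chat_mdl = build_chat_model("Replicate", "r8_xxx", "meta/meta-llama-3-8b-instruct")
assert "Replicate" in EmbeddingModel  # ReplicateEmbed is registered the same way
```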
The new `ReplicateChat` class wraps Replicate's client for both blocking and streaming chat:

```diff
@@ -1003,7 +1003,7 @@ class TogetherAIChat(Base):
         base_url = "https://api.together.xyz/v1"
         super().__init__(key, model_name, base_url)
 
 
 class PerfXCloudChat(Base):
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
         if not base_url:
@@ -1036,4 +1036,55 @@ class YiChat(Base):
     def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
         if not base_url:
             base_url = "https://api.01.ai/v1"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url)
+
+
+class ReplicateChat(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from replicate.client import Client
+
+        self.model_name = model_name
+        self.client = Client(api_token=key)
+        self.system = ""
+
+    def chat(self, system, history, gen_conf):
+        if "max_tokens" in gen_conf:
+            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
+        if system:
+            self.system = system
+        prompt = "\n".join(
+            [item["role"] + ":" + item["content"] for item in history[-5:]]
+        )
+        ans = ""
+        try:
+            response = self.client.run(
+                self.model_name,
+                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
+            )
+            ans = "".join(response)
+            return ans, num_tokens_from_string(ans)
+        except Exception as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if "max_tokens" in gen_conf:
+            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
+        if system:
+            self.system = system
+        prompt = "\n".join(
+            [item["role"] + ":" + item["content"] for item in history[-5:]]
+        )
+        ans = ""
+        try:
+            response = self.client.run(
+                self.model_name,
+                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
+            )
+            for resp in response:
+                ans += resp
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield num_tokens_from_string(ans)
```
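For reference, a usage sketch of the new chat wrapper. Two conventions are baked into the class: `max_tokens` is renamed to Replicate's `max_new_tokens`, and only the last five history turns are flattened into the prompt. The key and model slug below are placeholders, and the module path assumes ragflow's `rag/llm/chat_model.py` layout:

```python
from rag.llm.chat_model import ReplicateChat  # class added above

mdl = ReplicateChat("r8_xxx", "meta/meta-llama-3-8b-instruct")  # placeholder key/slug
history = [{"role": "user", "content": "Name two RAG retrieval strategies."}]
gen_conf = {"temperature": 0.7, "max_tokens": 256}  # renamed to max_new_tokens internally

# Blocking call: returns (answer, estimated token count).
# dict() copies because the wrapper mutates gen_conf via pop().
answer, tokens = mdl.chat("You are a helpful assistant.", history, dict(gen_conf))

# Streaming call: yields the cumulative answer text, then the token count last.
for chunk in mdl.chat_streamly("You are a helpful assistant.", history, dict(gen_conf)):
    if isinstance(chunk, int):
        print("token count:", chunk)
    else:
        answer = chunk  # each yield is the accumulated text so far
```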
And `ReplicateEmbed` does the same for embeddings (note: the PR as submitted called a nonexistent `client.embed`; the replicate client exposes `run`, used here for both methods):

```diff
@@ -561,7 +561,7 @@ class TogetherAIEmbed(OllamaEmbed):
         base_url = "https://api.together.xyz/v1"
         super().__init__(key, model_name, base_url)
 
 
 class PerfXCloudEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
         if not base_url:
@@ -580,4 +580,22 @@ class SILICONFLOWEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url)
+
+
+class ReplicateEmbed(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from replicate.client import Client
+
+        self.model_name = model_name
+        self.client = Client(api_token=key)
+
+    def encode(self, texts: list, batch_size=32):
+        from json import dumps
+
+        res = self.client.run(self.model_name, input={"texts": dumps(texts)})
+        return np.array(res), sum([num_tokens_from_string(text) for text in texts])
+
+    def encode_queries(self, text):
+        res = self.client.run(self.model_name, input={"texts": [text]})
+        return np.array(res), num_tokens_from_string(text)
```
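A matching sketch for the embedding wrapper. `encode` JSON-encodes the whole batch into a single `texts` field, so this only works with Replicate embedding models that accept that input schema; the key and model slug below are illustrative:

```python
from rag.llm.embedding_model import ReplicateEmbed  # class added above

# Placeholder key; the slug must be an embedding model whose input schema
# accepts a "texts" field (a JSON-encoded batch for encode()).
mdl = ReplicateEmbed("r8_xxx", "nateraw/bge-large-en-v1.5")

vectors, used_tokens = mdl.encode(["hello world", "retrieval augmented generation"])
print(vectors.shape, used_tokens)  # one vector per input text

qvec, qtokens = mdl.encode_queries("what is RAG?")
```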