Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
add support for Baidu yiyan (#2049)
### What problem does this PR solve?

add support for Baidu yiyan

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
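The three new provider classes share one credential convention: the `key` handed to each constructor is a JSON string carrying a Baidu Qianfan access-key/secret-key pair, which the class immediately `json.loads()`. A minimal sketch of building such a key (the values are placeholders, not real credentials):

```python
import json

# Credentials in the shape the new classes expect: a JSON string with
# "yiyan_ak" / "yiyan_sk" fields. The values below are placeholders.
key = json.dumps({"yiyan_ak": "YOUR_QIANFAN_AK", "yiyan_sk": "YOUR_QIANFAN_SK"})
```

The diff first registers the new provider in the embedding, chat, and rerank registries, then adds the three client classes.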
@@ -43,7 +43,8 @@ EmbeddingModel = {
     "PerfXCloud": PerfXCloudEmbed,
     "Upstage": UpstageEmbed,
     "SILICONFLOW": SILICONFLOWEmbed,
-    "Replicate": ReplicateEmbed
+    "Replicate": ReplicateEmbed,
+    "BaiduYiyan": BaiduYiyanEmbed
 }
@@ -101,7 +102,8 @@ ChatModel = {
     "01.AI": YiChat,
     "Replicate": ReplicateChat,
     "Tencent Hunyuan": HunyuanChat,
-    "XunFei Spark": SparkChat
+    "XunFei Spark": SparkChat,
+    "BaiduYiyan": BaiduYiyanChat
 }
@@ -115,7 +117,8 @@ RerankModel = {
     "OpenAI-API-Compatible": OpenAI_APIRerank,
     "cohere": CoHereRerank,
     "TogetherAI": TogetherAIRerank,
-    "SILICONFLOW": SILICONFLOWRerank
+    "SILICONFLOW": SILICONFLOWRerank,
+    "BaiduYiyan": BaiduYiyanRerank
 }
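These registries map the provider name users select to the implementing class, so adding the `"BaiduYiyan"` entries is what makes the new classes reachable. A sketch of the assumed lookup pattern (the model name is illustrative):

```python
# Pick the implementation by provider name, then instantiate it with the
# JSON-encoded credentials and a model name.
chat_cls = ChatModel["BaiduYiyan"]          # -> BaiduYiyanChat
chat_mdl = chat_cls(key, "ERNIE-Speed-8K")  # model name is a placeholder
```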
@@ -1185,3 +1185,69 @@ class SparkChat(Base):
         }
         model_version = model2version[model_name]
         super().__init__(key, model_version, base_url)
+
+
+class BaiduYiyanChat(Base):
+    def __init__(self, key, model_name, base_url=None):
+        import qianfan
+
+        key = json.loads(key)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
+        self.model_name = model_name.lower()
+        self.system = ""
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            self.system = system
+        # qianfan exposes a single penalty_score (range [1.0, 2.0]), so the two
+        # OpenAI-style penalties are averaged and shifted by 1 to fit it.
+        gen_conf["penalty_score"] = (
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
+        ) + 1
+        if "max_tokens" in gen_conf:
+            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
+        ans = ""
+
+        try:
+            response = self.client.do(
+                model=self.model_name,
+                messages=history,
+                system=self.system,
+                **gen_conf
+            ).body
+            ans = response["result"]
+            return ans, response["usage"]["total_tokens"]
+        except Exception as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            self.system = system
+        gen_conf["penalty_score"] = (
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
+        ) + 1
+        if "max_tokens" in gen_conf:
+            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
+        ans = ""
+        total_tokens = 0
+
+        try:
+            response = self.client.do(
+                model=self.model_name,
+                messages=history,
+                system=self.system,
+                stream=True,
+                **gen_conf
+            )
+            for resp in response:
+                resp = resp.body
+                ans += resp["result"]
+                total_tokens = resp["usage"]["total_tokens"]
+                yield ans
+        except Exception as e:
+            # Must be `yield`, not `return`: a `return value` inside a generator
+            # only raises StopIteration, so the caller would never see the error.
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
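`BaiduYiyanChat` adapts OpenAI-style generation settings to qianfan's: the two penalties are averaged and shifted by one (e.g. `presence_penalty=0.6` and `frequency_penalty=0.2` yield `penalty_score=1.4`), and `max_tokens` is mirrored into `max_output_tokens`. A hedged usage sketch, with `key` built as above and a placeholder model name:

```python
mdl = BaiduYiyanChat(key, "ERNIE-Speed-8K")
history = [{"role": "user", "content": "Introduce RAG in one sentence."}]

# Blocking call: returns (answer_text, total_tokens).
ans, used = mdl.chat("You are a concise assistant.", history, {"temperature": 0.7})

# Streaming call: yields the accumulated answer after each chunk,
# then the total token count as the final item.
*partials, total_tokens = mdl.chat_streamly("", history, {})
```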
@@ -32,6 +32,7 @@ import asyncio
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
 import google.generativeai as genai
+import json
 
 
 class Base(ABC):
     def __init__(self, key, model_name):
@@ -591,11 +592,34 @@ class ReplicateEmbed(Base):
         self.client = Client(api_token=key)
 
     def encode(self, texts: list, batch_size=32):
-        from json import dumps
-
-        res = self.client.run(self.model_name, input={"texts": dumps(texts)})
+        res = self.client.run(self.model_name, input={"texts": json.dumps(texts)})
         return np.array(res), sum([num_tokens_from_string(text) for text in texts])
 
     def encode_queries(self, text):
         res = self.client.embed(self.model_name, input={"texts": [text]})
         return np.array(res), num_tokens_from_string(text)
 
 
+class BaiduYiyanEmbed(Base):
+    def __init__(self, key, model_name, base_url=None):
+        import qianfan
+
+        key = json.loads(key)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = qianfan.Embedding(ak=ak, sk=sk)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.do(model=self.model_name, texts=texts).body
+        return (
+            np.array([r["embedding"] for r in res["data"]]),
+            res["usage"]["total_tokens"],
+        )
+
+    def encode_queries(self, text):
+        res = self.client.do(model=self.model_name, texts=[text]).body
+        return (
+            np.array([r["embedding"] for r in res["data"]]),
+            res["usage"]["total_tokens"],
+        )
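`BaiduYiyanEmbed` sends all texts in a single `client.do()` call and returns a `(len(texts), dim)` matrix plus the token usage the API reports. A quick sketch (the model name stands in for whichever qianfan embedding model is configured):

```python
emb_mdl = BaiduYiyanEmbed(key, "embedding-v1")
vecs, used = emb_mdl.encode(["hello world", "你好世界"])  # vecs.shape == (2, dim)
qvec, qused = emb_mdl.encode_queries("hello")             # single-query variant
```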
@@ -24,6 +24,7 @@ from abc import ABC
 import numpy as np
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
+import json
 
 
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
@@ -288,3 +289,25 @@ class SILICONFLOWRerank(Base):
             rank[indexs],
             response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
         )
+
+
+class BaiduYiyanRerank(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from qianfan.resources import Reranker
+
+        key = json.loads(key)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = Reranker(ak=ak, sk=sk)
+        self.model_name = model_name
+
+    def similarity(self, query: str, texts: list):
+        res = self.client.do(
+            model=self.model_name,
+            query=query,
+            documents=texts,
+            top_n=len(texts),
+        ).body
+        rank = np.array([d["relevance_score"] for d in res["results"]])
+        indexs = [d["index"] for d in res["results"]]
+        return rank[indexs], res["usage"]["total_tokens"]
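`BaiduYiyanRerank.similarity` asks the API to score every candidate (`top_n=len(texts)`) and returns an array of relevance scores plus the reported token usage. Sketch, with a placeholder model name:

```python
rr_mdl = BaiduYiyanRerank(key, "bce-reranker-base_v1")
scores, used = rr_mdl.similarity("what is RAG?", ["doc about RAG", "doc about SQL"])
```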