add support for SILICONFLOW (#1926)

### What problem does this PR solve?

#1853 add support for SILICONFLOW

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-08-13 16:09:10 +08:00
committed by GitHub
parent 06700850df
commit e013ac52af
7 changed files with 349 additions and 7 deletions

View File

@ -41,7 +41,8 @@ EmbeddingModel = {
"cohere": CoHereEmbed,
"TogetherAI": TogetherAIEmbed,
"PerfXCloud": PerfXCloudEmbed,
"Upstage": UpstageEmbed
"Upstage": UpstageEmbed,
"SILICONFLOW": SILICONFLOWEmbed
}
@ -92,7 +93,8 @@ ChatModel = {
"TogetherAI": TogetherAIChat,
"PerfXCloud": PerfXCloudChat,
"Upstage":UpstageChat,
"novita.ai": NovitaAIChat
"novita.ai": NovitaAIChat,
"SILICONFLOW": SILICONFLOWChat
}
@ -105,7 +107,8 @@ RerankModel = {
"LM-Studio": LmStudioRerank,
"OpenAI-API-Compatible": OpenAI_APIRerank,
"cohere": CoHereRerank,
"TogetherAI": TogetherAIRerank
"TogetherAI": TogetherAIRerank,
"SILICONFLOW": SILICONFLOWRerank
}

View File

@ -1016,4 +1016,10 @@ class NovitaAIChat(Base):
if not base_url:
base_url = "https://api.novita.ai/v3/openai"
super().__init__(key, model_name, base_url)
class SILICONFLOWChat(Base):
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
super().__init__(key, model_name, base_url)

View File

@ -574,3 +574,10 @@ class UpstageEmbed(OpenAIEmbed):
if not base_url:
base_url = "https://api.upstage.ai/v1/solar"
super().__init__(key, model_name, base_url)
class SILICONFLOWEmbed(OpenAIEmbed):
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
super().__init__(key, model_name, base_url)

View File

@ -252,4 +252,39 @@ class TogetherAIRerank(Base):
pass
def similarity(self, query: str, texts: list):
raise NotImplementedError("The api has not been implement")
raise NotImplementedError("The api has not been implement")
class SILICONFLOWRerank(Base):
def __init__(
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/rerank"
self.model_name = model_name
self.base_url = base_url
self.headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {key}",
}
def similarity(self, query: str, texts: list):
payload = {
"model": self.model_name,
"query": query,
"documents": texts,
"top_n": len(texts),
"return_documents": False,
"max_chunks_per_doc": 1024,
"overlap_tokens": 80,
}
response = requests.post(
self.base_url, json=payload, headers=self.headers
).json()
rank = np.array([d["relevance_score"] for d in response["results"]])
indexs = [d["index"] for d in response["results"]]
return (
rank[indexs],
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
)