add supprot for OpenAI-API-Compatible llm (#1787)

### What problem does this PR solve?

#1771  add supprot for OpenAI-API-Compatible 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-08-06 16:20:21 +08:00
committed by GitHub
parent 66e4113e0b
commit b67484e77d
12 changed files with 74 additions and 11 deletions

View File

@ -36,7 +36,8 @@ EmbeddingModel = {
"Bedrock": BedrockEmbed,
"Gemini": GeminiEmbed,
"NVIDIA": NvidiaEmbed,
"LM-Studio": LmStudioEmbed
"LM-Studio": LmStudioEmbed,
"OpenAI-API-Compatible": OpenAI_APIEmbed
}
@ -53,7 +54,8 @@ CvModel = {
"LocalAI": LocalAICV,
"NVIDIA": NvidiaCV,
"LM-Studio": LmStudioCV,
"StepFun":StepFunCV
"StepFun":StepFunCV,
"OpenAI-API-Compatible": OpenAI_APICV
}
@ -78,7 +80,8 @@ ChatModel = {
"OpenRouter": OpenRouterChat,
"StepFun": StepFunChat,
"NVIDIA": NvidiaChat,
"LM-Studio": LmStudioChat
"LM-Studio": LmStudioChat,
"OpenAI-API-Compatible": OpenAI_APIChat
}
@ -88,7 +91,8 @@ RerankModel = {
"Youdao": YoudaoRerank,
"Xinference": XInferenceRerank,
"NVIDIA": NvidiaRerank,
"LM-Studio": LmStudioRerank
"LM-Studio": LmStudioRerank,
"OpenAI-API-Compatible": OpenAI_APIRerank
}

View File

@ -887,6 +887,16 @@ class LmStudioChat(Base):
if not base_url:
raise ValueError("Local llm url cannot be None")
if base_url.split("/")[-1] != "v1":
self.base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
self.model_name = model_name
class OpenAI_APIChat(Base):
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
model_name = model_name.split("___")[0]
super().__init__(key, model_name, base_url)

View File

@ -638,3 +638,14 @@ class LmStudioCV(GptV4):
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
self.model_name = model_name
self.lang = lang
class OpenAI_APICV(GptV4):
def __init__(self, key, model_name, base_url, lang="Chinese"):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name.split("___")[0]
self.lang = lang

View File

@ -513,3 +513,13 @@ class LmStudioEmbed(LocalAIEmbed):
self.base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
self.model_name = model_name
class OpenAI_APIEmbed(OpenAIEmbed):
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1":
self.base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name.split("___")[0]

View File

@ -212,3 +212,11 @@ class LmStudioRerank(Base):
def similarity(self, query: str, texts: list):
raise NotImplementedError("The LmStudioRerank has not been implement")
class OpenAI_APIRerank(Base):
def __init__(self, key, model_name, base_url):
pass
def similarity(self, query: str, texts: list):
raise NotImplementedError("The api has not been implement")