Mirror of https://github.com/infiniflow/ragflow.git
Feat: add primitive support for function calls (#6840)
### What problem does this PR solve?

This PR introduces **primitive support for function calls**, enabling the system to handle basic function-call capabilities. However, this feature is currently experimental and **not yet enabled for general use**, as it is only supported by a subset of models, namely Qwen and OpenAI models.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
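For context, here is a minimal usage sketch of the new hooks this diff adds (`bind_tools` / `chat_with_tools`). Only the class and method names and their signatures are taken from the diff below; the import path, the tool schema, the `WeatherToolSession` class, and the key/model values are illustrative assumptions, not part of this PR.

```python
# Hypothetical usage sketch for the new tool-calling hooks (not part of this PR).
from rag.llm.chat_model import GptTurbo  # assumed import path for the model class


class WeatherToolSession:
    """Minimal tool-call session: maps a tool name plus parsed JSON args to a result."""

    def tool_call(self, name, arguments):
        if name == "get_current_weather":
            return {"location": arguments.get("location"), "forecast": "sunny"}
        return f"unknown tool: {name}"


tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

mdl = GptTurbo("sk-...", model_name="gpt-3.5-turbo")
mdl.bind_tools(WeatherToolSession(), tools)  # no-op unless both arguments are truthy
answer, token_count = mdl.chat_with_tools(
    "You are a helpful assistant.",
    [{"role": "user", "content": "What's the weather in Shanghai?"}],
    {"temperature": 0.1},
)
```

The same session object can serve both backends: the OpenAI-style path in the diff invokes `toolcall_session.tool_call(name, args)` positionally, while the Qwen path calls it with `name=` / `arguments=` keywords.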
@@ -59,6 +59,7 @@ class Base(ABC):
        # Configure retry parameters
        self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
        self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
        self.is_tools = False

    def _get_delay(self, attempt):
        """Calculate retry delay time"""
@@ -89,6 +90,91 @@ class Base(ABC):
        else:
            return ERROR_GENERIC

    def bind_tools(self, toolcall_session, tools):
        if not (toolcall_session and tools):
            return
        self.is_tools = True
        self.toolcall_session = toolcall_session
        self.tools = tools

    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]

        tools = self.tools

        if system:
            history.insert(0, {"role": "system", "content": system})

        ans = ""
        tk_count = 0
        # Implement exponential backoff retry strategy
        for attempt in range(self.max_retries):
            try:
                response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)

                assistant_output = response.choices[0].message
                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
                    ans += "<think>" + ans + "</think>"
                ans += response.choices[0].message.content

                if not response.choices[0].message.tool_calls:
                    tk_count += self.total_token_count(response)
                    if response.choices[0].finish_reason == "length":
                        if is_chinese([ans]):
                            ans += LENGTH_NOTIFICATION_CN
                        else:
                            ans += LENGTH_NOTIFICATION_EN
                    return ans, tk_count

                tk_count += self.total_token_count(response)
                history.append(assistant_output)

                for tool_call in response.choices[0].message.tool_calls:
                    name = tool_call.function.name
                    args = json.loads(tool_call.function.arguments)

                    tool_response = self.toolcall_session.tool_call(name, args)
                    # if tool_response.choices[0].finish_reason == "length":
                    #     if is_chinese(ans):
                    #         ans += LENGTH_NOTIFICATION_CN
                    #     else:
                    #         ans += LENGTH_NOTIFICATION_EN
                    #     return ans, tk_count + self.total_token_count(tool_response)
                    history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})

                final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
                assistant_output = final_response.choices[0].message
                if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
                    ans += "<think>" + ans + "</think>"
                ans += final_response.choices[0].message.content
                if final_response.choices[0].finish_reason == "length":
                    tk_count += self.total_token_count(response)
                    if is_chinese([ans]):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    return ans, tk_count
                return ans, tk_count

            except Exception as e:
                logging.exception("OpenAI chat_with_tools")
                # Classify the error
                error_code = self._classify_error(e)

                # Check if it's a rate limit error or server error and not the last attempt
                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1

                if should_retry:
                    delay = self._get_delay(attempt)
                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
                    time.sleep(delay)
                else:
                    # For non-rate limit errors or the last attempt, return an error message
                    if attempt == self.max_retries - 1:
                        error_code = ERROR_MAX_RETRIES
                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
@@ -127,6 +213,127 @@ class Base(ABC):
                        error_code = ERROR_MAX_RETRIES
                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}. response: {response}", 0

    def _wrap_toolcall_message(self, stream):
        final_tool_calls = {}

        for chunk in stream:
            for tool_call in chunk.choices[0].delta.tool_calls or []:
                index = tool_call.index

                if index not in final_tool_calls:
                    final_tool_calls[index] = tool_call

                final_tool_calls[index].function.arguments += tool_call.function.arguments

        return final_tool_calls

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]

        tools = self.tools

        if system:
            history.insert(0, {"role": "system", "content": system})

        ans = ""
        total_tokens = 0
        reasoning_start = False
        finish_completion = False
        final_tool_calls = {}
        try:
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
            while not finish_completion:
                for resp in response:
                    if resp.choices[0].delta.tool_calls:
                        for tool_call in resp.choices[0].delta.tool_calls or []:
                            index = tool_call.index

                            if index not in final_tool_calls:
                                final_tool_calls[index] = tool_call

                            final_tool_calls[index].function.arguments += tool_call.function.arguments
                        if resp.choices[0].finish_reason != "stop":
                            continue
                    else:
                        if not resp.choices:
                            continue
                        if not resp.choices[0].delta.content:
                            resp.choices[0].delta.content = ""
                        if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                            ans = ""
                            if not reasoning_start:
                                reasoning_start = True
                                ans = "<think>"
                            ans += resp.choices[0].delta.reasoning_content + "</think>"
                        else:
                            reasoning_start = False
                            ans = resp.choices[0].delta.content

                        tol = self.total_token_count(resp)
                        if not tol:
                            total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                        else:
                            total_tokens += tol

                    finish_reason = resp.choices[0].finish_reason
                    if finish_reason == "tool_calls" and final_tool_calls:
                        for tool_call in final_tool_calls.values():
                            name = tool_call.function.name
                            try:
                                if name == "get_current_weather":
                                    args = json.loads('{"location":"Shanghai"}')
                                else:
                                    args = json.loads(tool_call.function.arguments)
                            except Exception:
                                continue
                            # args = json.loads(tool_call.function.arguments)
                            tool_response = self.toolcall_session.tool_call(name, args)
                            history.append(
                                {
                                    "role": "assistant",
                                    "refusal": "",
                                    "content": "",
                                    "audio": "",
                                    "function_call": "",
                                    "tool_calls": [
                                        {
                                            "index": tool_call.index,
                                            "id": tool_call.id,
                                            "function": tool_call.function,
                                            "type": "function",
                                        },
                                    ],
                                }
                            )
                            # if tool_response.choices[0].finish_reason == "length":
                            #     if is_chinese(ans):
                            #         ans += LENGTH_NOTIFICATION_CN
                            #     else:
                            #         ans += LENGTH_NOTIFICATION_EN
                            #     return ans, total_tokens + self.total_token_count(tool_response)
                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
                        final_tool_calls = {}
                        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                        continue
                    if finish_reason == "length":
                        if is_chinese(ans):
                            ans += LENGTH_NOTIFICATION_CN
                        else:
                            ans += LENGTH_NOTIFICATION_EN
                        return ans, total_tokens + self.total_token_count(resp)
                    if finish_reason == "stop":
                        finish_completion = True
                        yield ans
                        break
                    yield ans
                    continue

        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
@@ -156,7 +363,7 @@ class Base(ABC):
                if not tol:
                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                else:
-                   total_tokens = tol
+                   total_tokens += tol

                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
@@ -180,9 +387,10 @@ class Base(ABC):
        except Exception:
            pass
        return 0

    def _calculate_dynamic_ctx(self, history):
        """Calculate dynamic context window size"""

        def count_tokens(text):
            """Calculate token count for text"""
            # Simple calculation: 1 token per ASCII character
@@ -207,15 +415,16 @@ class Base(ABC):

        # Apply 1.2x buffer ratio
        total_tokens_with_buffer = int(total_tokens * 1.2)

        if total_tokens_with_buffer <= 8192:
            ctx_size = 8192
        else:
            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
            ctx_size = ctx_multiplier * 8192

        return ctx_size


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
@@ -350,6 +559,8 @@ class BaiChuanChat(Base):

class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        super().__init__(key, model_name, base_url=None)

        import dashscope

        dashscope.api_key = key
@@ -357,6 +568,78 @@ class QWenChat(Base):
        if self.is_reasoning_model(self.model_name):
            super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")

    def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        # if self.is_reasoning_model(self.model_name):
        #     return super().chat(system, history, gen_conf)

        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
        if not stream_flag:
            from http import HTTPStatus

            tools = self.tools

            if system:
                history.insert(0, {"role": "system", "content": system})

            response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
            ans = ""
            tk_count = 0
            if response.status_code == HTTPStatus.OK:
                assistant_output = response.output.choices[0].message
                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
                    ans += "<think>" + ans + "</think>"
                ans += response.output.choices[0].message.content

                if "tool_calls" not in assistant_output:
                    tk_count += self.total_token_count(response)
                    if response.output.choices[0].get("finish_reason", "") == "length":
                        if is_chinese([ans]):
                            ans += LENGTH_NOTIFICATION_CN
                        else:
                            ans += LENGTH_NOTIFICATION_EN
                    return ans, tk_count

                tk_count += self.total_token_count(response)
                history.append(assistant_output)

                while "tool_calls" in assistant_output:
                    tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
                    tool_name = assistant_output.tool_calls[0]["function"]["name"]
                    if tool_name:
                        arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
                        tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments)
                    history.append(tool_info)

                    response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
                    if response.output.choices[0].get("finish_reason", "") == "length":
                        tk_count += self.total_token_count(response)
                        if is_chinese([ans]):
                            ans += LENGTH_NOTIFICATION_CN
                        else:
                            ans += LENGTH_NOTIFICATION_EN
                        return ans, tk_count

                    tk_count += self.total_token_count(response)
                    assistant_output = response.output.choices[0].message
                    if assistant_output.content is None:
                        assistant_output.content = ""
                    history.append(response)
                    ans += assistant_output["content"]
                return ans, tk_count
            else:
                return "**ERROR**: " + response.message, tk_count
        else:
            result_list = []
            for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
                result_list.append(result)
            error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
            if len(error_msg_list) > 0:
                return "**ERROR**: " + "".join(error_msg_list), 0
            else:
                return "".join(result_list[:-1]), result_list[-1]

    def chat(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
@@ -393,6 +676,99 @@ class QWenChat(Base):
            else:
                return "".join(result_list[:-1]), result_list[-1]

    def _wrap_toolcall_message(self, old_message, message):
        if not old_message:
            return message
        tool_call_id = message["tool_calls"][0].get("id")
        if tool_call_id:
            old_message.tool_calls[0]["id"] = tool_call_id
        function = message.tool_calls[0]["function"]
        if function:
            if function.get("name"):
                old_message.tool_calls[0]["function"]["name"] = function["name"]
            if function.get("arguments"):
                old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
        return old_message

    def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
        from http import HTTPStatus

        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
            tool_info = {"content": "", "role": "tool"}
            toolcall_message = None
            tool_name = ""
            tool_arguments = ""
            finish_completion = False
            reasoning_start = False
            while not finish_completion:
                for resp in response:
                    if resp.status_code == HTTPStatus.OK:
                        assistant_output = resp.output.choices[0].message
                        ans = resp.output.choices[0].message.content
                        if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
                            ans = resp.output.choices[0].message.reasoning_content
                            if not reasoning_start:
                                reasoning_start = True
                                ans = "<think>" + ans
                            else:
                                ans = ans + "</think>"

                        if "tool_calls" not in assistant_output:
                            reasoning_start = False
                            tk_count += self.total_token_count(resp)
                            if resp.output.choices[0].get("finish_reason", "") == "length":
                                if is_chinese([ans]):
                                    ans += LENGTH_NOTIFICATION_CN
                                else:
                                    ans += LENGTH_NOTIFICATION_EN
                            finish_reason = resp.output.choices[0]["finish_reason"]
                            if finish_reason == "stop":
                                finish_completion = True
                                yield ans
                                break
                            yield ans
                            continue

                        tk_count += self.total_token_count(resp)
                        toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
                        if "tool_calls" in assistant_output:
                            tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
                            if tool_call_finish_reason == "tool_calls":
                                try:
                                    tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
                                except Exception as e:
                                    logging.exception(msg="_chat_streamly_with_tool tool call error")
                                    yield ans + "\n**ERROR**: " + str(e)
                                    finish_completion = True
                                    break

                                tool_name = toolcall_message.tool_calls[0]["function"]["name"]
                                history.append(toolcall_message)
                                tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments)
                                history.append(tool_info)
                                tool_info = {"content": "", "role": "tool"}
                                tool_name = ""
                                tool_arguments = ""
                                toolcall_message = None
                                response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
                    else:
                        yield (
                            ans + "\n**ERROR**: " + resp.output.choices[0].message
                            if not re.search(r" (key|quota)", str(resp.message).lower())
                            else "Out of credit. Please set the API key in **settings > Model providers.**"
                        )
        except Exception as e:
            logging.exception(msg="_chat_streamly_with_tool")
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

    def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
        from http import HTTPStatus
@@ -425,6 +801,13 @@ class QWenChat(Base):

        yield tk_count

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]

        for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
            yield txt

    def chat_streamly(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
@@ -445,6 +828,8 @@ class QWenChat(Base):

class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        super().__init__(key, model_name, base_url=None)

        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name
@@ -504,6 +889,8 @@ class ZhipuChat(Base):

class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        super().__init__(key, model_name, base_url=None)

        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name
@@ -515,10 +902,8 @@ class OllamaChat(Base):
        try:
            # Calculate context size
            ctx_size = self._calculate_dynamic_ctx(history)

-           options = {
-               "num_ctx": ctx_size
-           }
+           options = {"num_ctx": ctx_size}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
@@ -545,9 +930,7 @@ class OllamaChat(Base):
        try:
            # Calculate context size
            ctx_size = self._calculate_dynamic_ctx(history)
-           options = {
-               "num_ctx": ctx_size
-           }
+           options = {"num_ctx": ctx_size}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
@@ -561,7 +944,7 @@ class OllamaChat(Base):

        ans = ""
        try:
-           response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10 )
+           response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
            for resp in response:
                if resp["done"]:
                    token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -578,6 +961,8 @@ class OllamaChat(Base):

class LocalAIChat(Base):
    def __init__(self, key, model_name, base_url):
        super().__init__(key, model_name, base_url=None)

        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
@@ -613,6 +998,8 @@ class LocalLLM(Base):
        return do_rpc

    def __init__(self, key, model_name):
        super().__init__(key, model_name, base_url=None)

        from jina import Client

        self.client = Client(port=12345, protocol="grpc", asyncio=True)
@@ -659,6 +1046,8 @@ class LocalLLM(Base):

class VolcEngineChat(Base):
    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        super().__init__(key, model_name, base_url=None)

        """
        Since we do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
        we assemble ark_api_key and ep_id into api_key, store it as a dictionary, and parse it for use.
@@ -677,6 +1066,8 @@ class MiniMaxChat(Base):
        model_name,
        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
    ):
        super().__init__(key, model_name, base_url=None)

        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        self.base_url = base_url
@@ -755,6 +1146,8 @@ class MiniMaxChat(Base):

class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        super().__init__(key, model_name, base_url=None)

        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
@@ -808,6 +1201,8 @@ class MistralChat(Base):

class BedrockChat(Base):
    def __init__(self, key, model_name, **kwargs):
        super().__init__(key, model_name, base_url=None)

        import boto3

        self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
@@ -887,6 +1282,8 @@ class BedrockChat(Base):

class GeminiChat(Base):
    def __init__(self, key, model_name, base_url=None):
        super().__init__(key, model_name, base_url=None)

        from google.generativeai import GenerativeModel, client

        client.configure(api_key=key)
@@ -947,6 +1344,8 @@ class GeminiChat(Base):

class GroqChat(Base):
    def __init__(self, key, model_name, base_url=""):
        super().__init__(key, model_name, base_url=None)

        from groq import Groq

        self.client = Groq(api_key=key)
@@ -1049,6 +1448,8 @@ class PPIOChat(Base):

class CoHereChat(Base):
    def __init__(self, key, model_name, base_url=""):
        super().__init__(key, model_name, base_url=None)

        from cohere import Client

        self.client = Client(api_key=key)
@@ -1171,6 +1572,8 @@ class YiChat(Base):

class ReplicateChat(Base):
    def __init__(self, key, model_name, base_url=None):
        super().__init__(key, model_name, base_url=None)

        from replicate.client import Client

        self.model_name = model_name
@@ -1218,6 +1621,8 @@ class ReplicateChat(Base):

class HunyuanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        super().__init__(key, model_name, base_url=None)

        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -1321,6 +1726,8 @@ class SparkChat(Base):

class BaiduYiyanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        super().__init__(key, model_name, base_url=None)

        import qianfan

        key = json.loads(key)
@@ -1372,6 +1779,8 @@ class BaiduYiyanChat(Base):

class AnthropicChat(Base):
    def __init__(self, key, model_name, base_url=None):
        super().__init__(key, model_name, base_url=None)

        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
@@ -1452,6 +1861,8 @@ class AnthropicChat(Base):

class GoogleChat(Base):
    def __init__(self, key, model_name, base_url=None):
        super().__init__(key, model_name, base_url=None)

        import base64

        from google.oauth2 import service_account