Refa: change LLM chat output from full to delta (incremental) (#6534)

### What problem does this PR solve?

Change LLM streaming chat output from full to delta (incremental): each streamed chunk now carries only the newly generated content instead of the whole answer accumulated so far.
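
For illustration, a minimal sketch of the before/after streaming contract (the function names below are illustrative, not the exact `Base.chat_streamly` signatures in the repo), assuming an openai-style client whose stream chunks expose `.choices[0].delta.content`:

```python
# Sketch only: contrasts full (old) vs. delta (new) streaming output.
from openai import OpenAI


def chat_streamly_full(client: OpenAI, model_name: str, history: list[dict], **gen_conf):
    """Old behaviour: every yield repeats the whole answer accumulated so far."""
    ans = ""
    for resp in client.chat.completions.create(model=model_name, messages=history, stream=True, **gen_conf):
        if not resp.choices or not resp.choices[0].delta.content:
            continue
        ans += resp.choices[0].delta.content  # grow the full answer
        yield ans                             # re-send everything on each chunk


def chat_streamly_delta(client: OpenAI, model_name: str, history: list[dict], **gen_conf):
    """New behaviour: every yield carries only the incremental piece."""
    for resp in client.chat.completions.create(model=model_name, messages=history, stream=True, **gen_conf):
        if not resp.choices or not resp.choices[0].delta.content:
            continue
        yield resp.choices[0].delta.content   # delta only; the caller concatenates


# Callers that previously displayed the latest yield verbatim now accumulate:
#   answer = "".join(chat_streamly_delta(client, "some-model", history))
```

Downstream consumers that rendered the latest yield directly must now join the deltas themselves.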

### Type of change

- [x] Refactoring
Author: Yongteng Lei
Committed: 2025-03-26 19:33:14 +08:00 (committed by GitHub)
Parent: 6599db1e99
Commit: df3890827d
3 changed files with 277 additions and 399 deletions


@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,25 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import asyncio
import json
import logging
import os
import random
import re
import time
from abc import ABC
import openai
import requests
from dashscope import Generation
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string
import os
import json
import requests
import asyncio
import logging
import time
# Error message constants
ERROR_PREFIX = "**ERROR**"
@ -53,21 +53,21 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
class Base(ABC):
def __init__(self, key, model_name, base_url):
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name
# Configure retry parameters
self.max_retries = int(os.environ.get('LLM_MAX_RETRIES', 5))
self.base_delay = float(os.environ.get('LLM_BASE_DELAY', 2.0))
self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
def _get_delay(self, attempt):
"""Calculate retry delay time"""
return self.base_delay * (2 ** attempt) + random.uniform(0, 0.5)
return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
return ERROR_RATE_LIMIT
elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
@ -98,11 +98,8 @@ class Base(ABC):
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf)
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
@ -111,17 +108,17 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, self.total_token_count(response)
except Exception as e:
# Classify the error
error_code = self._classify_error(e)
# Check if it's a rate limit error or server error and not the last attempt
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
if should_retry:
delay = self._get_delay(attempt)
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt+1}/{self.max_retries})")
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
else:
# For non-rate limit errors or the last attempt, return an error message
@ -136,24 +133,23 @@ class Base(ABC):
del gen_conf["max_tokens"]
ans = ""
total_tokens = 0
reasoning_start = False
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
stream=True,
**gen_conf)
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
if ans.find("<think>") < 0:
ans += "<think>"
ans = ans.replace("</think>", "")
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
ans += resp.choices[0].delta.content
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
@ -221,7 +217,7 @@ class ModelScopeChat(Base):
def __init__(self, key=None, model_name="", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = base_url.rstrip('/')
base_url = base_url.rstrip("/")
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
super().__init__(key, model_name.split("___")[0], base_url)
@ -236,8 +232,8 @@ class DeepSeekChat(Base):
class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
api_key = json.loads(key).get('api_key', '')
api_version = json.loads(key).get('api_version', '2024-02-01')
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
@ -264,16 +260,9 @@ class BaiChuanChat(Base):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={
"tools": [{
"type": "web_search",
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
**self._format_params(gen_conf))
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
**self._format_params(gen_conf),
)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese([ans]):
@ -295,23 +284,16 @@ class BaiChuanChat(Base):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={
"tools": [{
"type": "web_search",
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
stream=True,
**self._format_params(gen_conf))
**self._format_params(gen_conf),
)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
@ -333,6 +315,7 @@ class BaiChuanChat(Base):
class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
import dashscope
dashscope.api_key = key
self.model_name = model_name
if self.is_reasoning_model(self.model_name):
@ -344,22 +327,18 @@ class QWenChat(Base):
if self.is_reasoning_model(self.model_name):
return super().chat(system, history, gen_conf)
stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
if not stream_flag:
from http import HTTPStatus
if system:
history.insert(0, {"role": "system", "content": system})
response = Generation.call(
self.model_name,
messages=history,
result_format='message',
**gen_conf
)
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
ans += response.output.choices[0]['message']['content']
ans += response.output.choices[0]["message"]["content"]
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
@ -378,8 +357,9 @@ class QWenChat(Base):
else:
return "".join(result_list[:-1]), result_list[-1]
def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
from http import HTTPStatus
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@ -387,17 +367,10 @@ class QWenChat(Base):
ans = ""
tk_count = 0
try:
response = Generation.call(
self.model_name,
messages=history,
result_format='message',
stream=True,
incremental_output=incremental_output,
**gen_conf
)
response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
for resp in response:
if resp.status_code == HTTPStatus.OK:
ans = resp.output.choices[0]['message']['content']
ans = resp.output.choices[0]["message"]["content"]
tk_count = self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese(ans):
@ -406,8 +379,11 @@ class QWenChat(Base):
ans += LENGTH_NOTIFICATION_EN
yield ans
else:
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)",
str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
yield (
ans + "\n**ERROR**: " + resp.message
if not re.search(r" (key|quota)", str(resp.message).lower())
else "Out of credit. Please set the API key in **settings > Model providers.**"
)
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -423,10 +399,12 @@ class QWenChat(Base):
@staticmethod
def is_reasoning_model(model_name: str) -> bool:
return any([
model_name.lower().find("deepseek") >= 0,
model_name.lower().find("qwq") >= 0 and model_name.lower() != 'qwq-32b-preview',
])
return any(
[
model_name.lower().find("deepseek") >= 0,
model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
]
)
class ZhipuChat(Base):
@ -444,11 +422,7 @@ class ZhipuChat(Base):
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -471,17 +445,12 @@ class ZhipuChat(Base):
ans = ""
tk_count = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
stream=True,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
ans = delta
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
@ -499,8 +468,7 @@ class ZhipuChat(Base):
class OllamaChat(Base):
def __init__(self, key, model_name, **kwargs):
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.model_name = model_name
def chat(self, system, history, gen_conf):
@ -509,9 +477,7 @@ class OllamaChat(Base):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
try:
options = {
"num_ctx": 32768
}
options = {"num_ctx": 32768}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
@ -522,12 +488,7 @@ class OllamaChat(Base):
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
options=options,
keep_alive=-1
)
response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
ans = response["message"]["content"].strip()
return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
except Exception as e:
@ -551,17 +512,11 @@ class OllamaChat(Base):
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
model=self.model_name,
messages=history,
stream=True,
options=options,
keep_alive=-1
)
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
for resp in response:
if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
ans += resp["message"]["content"]
ans = resp["message"]["content"]
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -588,9 +543,7 @@ class LocalLLM(Base):
def __conn(self):
from multiprocessing.connection import Client
self._connection = Client(
(self.host, self.port), authkey=b"infiniflow-token4kevinhu"
)
self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu")
def __getattr__(self, name):
import pickle
@ -613,17 +566,17 @@ class LocalLLM(Base):
def _prepare_prompt(self, system, history, gen_conf):
from rag.svr.jina_server import Prompt
if system:
history.insert(0, {"role": "system", "content": system})
return Prompt(message=history, gen_conf=gen_conf)
def _stream_response(self, endpoint, prompt):
from rag.svr.jina_server import Generation
answer = ""
try:
res = self.client.stream_doc(
on=endpoint, inputs=prompt, return_type=Generation
)
res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
loop = asyncio.get_event_loop()
try:
while True:
@ -652,24 +605,24 @@ class LocalLLM(Base):
class VolcEngineChat(Base):
def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'):
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
"""
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
model_name is for display only
"""
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
ark_api_key = json.loads(key).get('ark_api_key', '')
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url)
class MiniMaxChat(Base):
def __init__(
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
):
if not base_url:
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@ -687,13 +640,9 @@ class MiniMaxChat(Base):
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = json.dumps(
{"model": self.model_name, "messages": history, **gen_conf}
)
payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf})
try:
response = requests.request(
"POST", url=self.base_url, headers=headers, data=payload
)
response = requests.request("POST", url=self.base_url, headers=headers, data=payload)
response = response.json()
ans = response["choices"][0]["message"]["content"].strip()
if response["choices"][0]["finish_reason"] == "length":
@ -737,7 +686,7 @@ class MiniMaxChat(Base):
text = ""
if "choices" in resp and "delta" in resp["choices"][0]:
text = resp["choices"][0]["delta"]["content"]
ans += text
ans = text
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(text)
@ -752,9 +701,9 @@ class MiniMaxChat(Base):
class MistralChat(Base):
def __init__(self, key, model_name, base_url=None):
from mistralai.client import MistralClient
self.client = MistralClient(api_key=key)
self.model_name = model_name
@ -765,10 +714,7 @@ class MistralChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
try:
response = self.client.chat(
model=self.model_name,
messages=history,
**gen_conf)
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -788,14 +734,11 @@ class MistralChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat_stream(
model=self.model_name,
messages=history,
**gen_conf)
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
ans = resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -811,23 +754,23 @@ class MistralChat(Base):
class BedrockChat(Base):
def __init__(self, key, model_name, **kwargs):
import boto3
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
self.bedrock_region = json.loads(key).get('bedrock_region', '')
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client('bedrock-runtime')
self.client = boto3.client("bedrock-runtime")
else:
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def chat(self, system, history, gen_conf):
from botocore.exceptions import ClientError
for k in list(gen_conf.keys()):
if k not in ["temperature"]:
del gen_conf[k]
@ -853,6 +796,7 @@ class BedrockChat(Base):
def chat_streamly(self, system, history, gen_conf):
from botocore.exceptions import ClientError
for k in list(gen_conf.keys()):
if k not in ["temperature"]:
del gen_conf[k]
@ -860,14 +804,9 @@ class BedrockChat(Base):
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text": item["content"]}]
if self.model_name.split('.')[0] == 'ai21':
if self.model_name.split(".")[0] == "ai21":
try:
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}]
)
response = self.client.converse(modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}])
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
@ -878,16 +817,13 @@ class BedrockChat(Base):
try:
# Send the message to the model, using a basic inference configuration.
streaming_response = self.client.converse_stream(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}]
modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}]
)
# Extract and print the streamed response text in real-time.
for resp in streaming_response["stream"]:
if "contentBlockDelta" in resp:
ans += resp["contentBlockDelta"]["delta"]["text"]
ans = resp["contentBlockDelta"]["delta"]["text"]
yield ans
except (ClientError, Exception) as e:
@ -897,13 +833,12 @@ class BedrockChat(Base):
class GeminiChat(Base):
def __init__(self, key, model_name, base_url=None):
from google.generativeai import client, GenerativeModel
from google.generativeai import GenerativeModel, client
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = 'models/' + model_name
self.model_name = "models/" + model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
@ -916,17 +851,15 @@ class GeminiChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'role' in item and item['role'] == 'system':
item['role'] = 'user'
if 'content' in item:
item['parts'] = item.pop('content')
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "role" in item and item["role"] == "system":
item["role"] = "user"
if "content" in item:
item["parts"] = item.pop("content")
try:
response = self.model.generate_content(
history,
generation_config=gen_conf)
response = self.model.generate_content(history, generation_config=gen_conf)
ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
@ -941,17 +874,15 @@ class GeminiChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item:
item['parts'] = item.pop('content')
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
item["parts"] = item.pop("content")
ans = ""
try:
response = self.model.generate_content(
history,
generation_config=gen_conf, stream=True)
response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
for resp in response:
ans += resp.text
ans = resp.text
yield ans
yield response._chunks[-1].usage_metadata.total_token_count
@ -962,8 +893,9 @@ class GeminiChat(Base):
class GroqChat(Base):
def __init__(self, key, model_name, base_url=''):
def __init__(self, key, model_name, base_url=""):
from groq import Groq
self.client = Groq(api_key=key)
self.model_name = model_name
@ -975,11 +907,7 @@ class GroqChat(Base):
del gen_conf[k]
ans = ""
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -999,16 +927,11 @@ class GroqChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
stream=True,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
ans = resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -1096,16 +1019,10 @@ class CoHereChat(Base):
mes = history.pop()["message"]
ans = ""
try:
response = self.client.chat(
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
response = self.client.chat(model=self.model_name, chat_history=history, message=mes, **gen_conf)
ans = response.text
if response.finish_reason == "MAX_TOKENS":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
@ -1133,20 +1050,14 @@ class CoHereChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat_stream(
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
response = self.client.chat_stream(model=self.model_name, chat_history=history, message=mes, **gen_conf)
for resp in response:
if resp.event_type == "text-generation":
ans += resp.text
ans = resp.text
total_tokens += num_tokens_from_string(resp.text)
elif resp.event_type == "stream-end":
if resp.finish_reason == "MAX_TOKENS":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans
except Exception as e:
@ -1217,9 +1128,7 @@ class ReplicateChat(Base):
del gen_conf["max_tokens"]
if system:
self.system = system
prompt = "\n".join(
[item["role"] + ":" + item["content"] for item in history[-5:]]
)
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
ans = ""
try:
response = self.client.run(
@ -1236,9 +1145,7 @@ class ReplicateChat(Base):
del gen_conf["max_tokens"]
if system:
self.system = system
prompt = "\n".join(
[item["role"] + ":" + item["content"] for item in history[-5:]]
)
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
ans = ""
try:
response = self.client.run(
@ -1246,7 +1153,7 @@ class ReplicateChat(Base):
input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
)
for resp in response:
ans += resp
ans = resp
yield ans
except Exception as e:
@ -1268,10 +1175,10 @@ class HunyuanChat(Base):
self.client = hunyuan_client.HunyuanClient(cred, "")
def chat(self, system, history, gen_conf):
from tencentcloud.hunyuan.v20230901 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
from tencentcloud.hunyuan.v20230901 import models
_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@ -1296,10 +1203,10 @@ class HunyuanChat(Base):
return ans + "\n**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf):
from tencentcloud.hunyuan.v20230901 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
from tencentcloud.hunyuan.v20230901 import models
_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@ -1327,7 +1234,7 @@ class HunyuanChat(Base):
resp = json.loads(resp["data"])
if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
continue
ans += resp["Choices"][0]["Delta"]["Content"]
ans = resp["Choices"][0]["Delta"]["Content"]
total_tokens += 1
yield ans
@ -1339,9 +1246,7 @@ class HunyuanChat(Base):
class SparkChat(Base):
def __init__(
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
):
def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
model2version = {
@ -1374,22 +1279,14 @@ class BaiduYiyanChat(Base):
def chat(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
try:
response = self.client.do(
model=self.model_name,
messages=history,
system=self.system,
**gen_conf
).body
ans = response['result']
response = self.client.do(model=self.model_name, messages=history, system=self.system, **gen_conf).body
ans = response["result"]
return ans, self.total_token_count(response)
except Exception as e:
@ -1398,26 +1295,17 @@ class BaiduYiyanChat(Base):
def chat_streamly(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
total_tokens = 0
try:
response = self.client.do(
model=self.model_name,
messages=history,
system=self.system,
stream=True,
**gen_conf
)
response = self.client.do(model=self.model_name, messages=history, system=self.system, stream=True, **gen_conf)
for resp in response:
resp = resp.body
ans += resp['result']
ans = resp["result"]
total_tokens = self.total_token_count(resp)
yield ans
@ -1458,11 +1346,7 @@ class AnthropicChat(Base):
).to_dict()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@ -1483,6 +1367,7 @@ class AnthropicChat(Base):
ans = ""
total_tokens = 0
reasoning_start = False
try:
response = self.client.messages.create(
model=self.model_name,
@ -1492,15 +1377,17 @@ class AnthropicChat(Base):
**gen_conf,
)
for res in response:
if res.type == 'content_block_delta':
if res.type == "content_block_delta":
if res.delta.type == "thinking_delta" and res.delta.thinking:
if ans.find("<think>") < 0:
ans += "<think>"
ans = ans.replace("</think>", "")
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += res.delta.thinking + "</think>"
else:
reasoning_start = False
text = res.delta.text
ans += text
ans = text
total_tokens += num_tokens_from_string(text)
yield ans
except Exception as e:
@ -1511,13 +1398,12 @@ class AnthropicChat(Base):
class GoogleChat(Base):
def __init__(self, key, model_name, base_url=None):
from google.oauth2 import service_account
import base64
from google.oauth2 import service_account
key = json.loads(key)
access_token = json.loads(
base64.b64decode(key.get("google_service_account_key", ""))
)
access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
project_id = key.get("google_project_id", "")
region = key.get("google_region", "")
@ -1530,28 +1416,20 @@ class GoogleChat(Base):
from google.auth.transport.requests import Request
if access_token:
credits = service_account.Credentials.from_service_account_info(
access_token, scopes=scopes
)
credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
request = Request()
credits.refresh(request)
token = credits.token
self.client = AnthropicVertex(
region=region, project_id=project_id, access_token=token
)
self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
else:
self.client = AnthropicVertex(region=region, project_id=project_id)
else:
from google.cloud import aiplatform
import vertexai.generative_models as glm
from google.cloud import aiplatform
if access_token:
credits = service_account.Credentials.from_service_account_info(
access_token
)
aiplatform.init(
credentials=credits, project=project_id, location=region
)
credits = service_account.Credentials.from_service_account_info(access_token)
aiplatform.init(credentials=credits, project=project_id, location=region)
else:
aiplatform.init(project=project_id, location=region)
self.client = glm.GenerativeModel(model_name=self.model_name)
@ -1573,15 +1451,10 @@ class GoogleChat(Base):
).json()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response["usage"]["input_tokens"]
+ response["usage"]["output_tokens"],
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
)
except Exception as e:
return "\n**ERROR**: " + str(e), 0
@ -1598,9 +1471,7 @@ class GoogleChat(Base):
if "content" in item:
item["parts"] = item.pop("content")
try:
response = self.client.generate_content(
history, generation_config=gen_conf
)
response = self.client.generate_content(history, generation_config=gen_conf)
ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
@ -1627,7 +1498,7 @@ class GoogleChat(Base):
res = res.decode("utf-8")
if "content_block_delta" in res and "data" in res:
text = json.loads(res[6:])["delta"]["text"]
ans += text
ans = text
total_tokens += num_tokens_from_string(text)
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -1647,11 +1518,9 @@ class GoogleChat(Base):
item["parts"] = item.pop("content")
ans = ""
try:
response = self.model.generate_content(
history, generation_config=gen_conf, stream=True
)
response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
for resp in response:
ans += resp.text
ans = resp.text
yield ans
except Exception as e: