mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add stream chat (#811)
### What problem does this PR solve? #709 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -20,7 +20,6 @@ from openai import OpenAI
|
||||
import openai
|
||||
from ollama import Client
|
||||
from rag.nlp import is_english
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
@ -44,6 +43,31 @@ class Base(ABC):
|
||||
except openai.APIError as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
stream=True,
|
||||
**gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:continue
|
||||
ans += resp.choices[0].delta.content
|
||||
total_tokens += 1
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
yield ans
|
||||
|
||||
except openai.APIError as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
|
||||
|
||||
class GptTurbo(Base):
|
||||
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
||||
@ -97,6 +121,35 @@ class QWenChat(Base):
|
||||
|
||||
return "**ERROR**: " + response.message, tk_count
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
from http import HTTPStatus
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
ans = ""
|
||||
try:
|
||||
response = Generation.call(
|
||||
self.model_name,
|
||||
messages=history,
|
||||
result_format='message',
|
||||
stream=True,
|
||||
**gen_conf
|
||||
)
|
||||
tk_count = 0
|
||||
for resp in response:
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
ans = resp.output.choices[0]['message']['content']
|
||||
tk_count = resp.usage.total_tokens
|
||||
if resp.output.choices[0].get("finish_reason", "") == "length":
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
yield ans
|
||||
else:
|
||||
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield tk_count
|
||||
|
||||
|
||||
class ZhipuChat(Base):
|
||||
def __init__(self, key, model_name="glm-3-turbo", **kwargs):
|
||||
@ -122,6 +175,34 @@ class ZhipuChat(Base):
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
stream=True,
|
||||
**gen_conf
|
||||
)
|
||||
tk_count = 0
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:continue
|
||||
delta = resp.choices[0].delta.content
|
||||
ans += delta
|
||||
tk_count = resp.usage.total_tokens if response.usage else 0
|
||||
if resp.output.choices[0].finish_reason == "length":
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield tk_count
|
||||
|
||||
|
||||
class OllamaChat(Base):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
@ -148,3 +229,28 @@ class OllamaChat(Base):
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
options = {}
|
||||
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
|
||||
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
stream=True,
|
||||
options=options
|
||||
)
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
return resp["prompt_eval_count"] + resp["eval_count"]
|
||||
ans = resp["message"]["content"]
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
yield 0
|
||||
|
||||
Reference in New Issue
Block a user