add stream chat (#811)
### What problem does this PR solve?

#709

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
@@ -20,7 +20,6 @@ from openai import OpenAI
 import openai
 from ollama import Client
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string


 class Base(ABC):
```
```diff
@@ -44,6 +43,31 @@ class Base(ABC):
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
+
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
```
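The new `chat_streamly` generators share a small protocol: every yielded item except the last is the accumulated answer string, and the final item is the token count. A minimal consumer sketch (not part of the PR; `mdl`, `system`, `history`, and `gen_conf` are placeholders):

```python
# Minimal consumer sketch for the chat_streamly protocol: every yielded item
# except the last is the accumulated answer so far; the final item is an int
# token count. `mdl` is any of the chat classes in this diff.
def consume_stream(mdl, system, history, gen_conf):
    answer, tokens = "", 0
    for item in mdl.chat_streamly(system, history, gen_conf):
        if isinstance(item, int):   # trailing token count
            tokens = item
        else:                       # partial (cumulative) answer
            answer = item
    return answer, tokens
```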
```diff
@@ -97,6 +121,35 @@ class QWenChat(Base):

         return "**ERROR**: " + response.message, tk_count

+    def chat_streamly(self, system, history, gen_conf):
+        from http import HTTPStatus
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        try:
+            response = Generation.call(
+                self.model_name,
+                messages=history,
+                result_format='message',
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if resp.status_code == HTTPStatus.OK:
+                    ans = resp.output.choices[0]['message']['content']
+                    tk_count = resp.usage.total_tokens
+                    if resp.output.choices[0].get("finish_reason", "") == "length":
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    yield ans
+                else:
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
+
 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
```
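Note the contrast with `Base.chat_streamly`: here `ans` is assigned rather than appended, which matches dashscope streaming chunks carrying the full answer so far (an assumption about the SDK's default, non-incremental mode; the PR does not state it). A toy sketch of the two stream shapes:

```python
# Toy sketch of the two stream shapes this diff handles. Assumption: dashscope
# streams cumulative snapshots by default, while OpenAI-style APIs stream deltas.
def from_deltas(deltas):            # OpenAI / Zhipu / Ollama style
    ans = ""
    for d in deltas:
        ans += d                    # accumulate each incremental piece
        yield ans

def from_cumulative(snapshots):     # dashscope (QWen) style
    for s in snapshots:
        yield s                     # each chunk already holds the answer so far

assert list(from_deltas(["Hel", "lo"])) == ["Hel", "Hello"]
assert list(from_cumulative(["Hel", "Hello"])) == ["Hel", "Hello"]
```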
```diff
@@ -122,6 +175,34 @@ class ZhipuChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                tk_count = resp.usage.total_tokens if resp.usage else 0
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
+
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
```
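`ZhipuChat` strips `presence_penalty` and `frequency_penalty` before the call, apparently because the ZhipuAI endpoint rejects them (an inference; the PR gives no reason). A hypothetical generalization of that inline scrubbing:

```python
# Hypothetical helper (not in the PR) generalizing ZhipuChat's inline scrubbing:
# drop generation options a given provider is known not to accept.
UNSUPPORTED_OPTIONS = {
    "zhipu": {"presence_penalty", "frequency_penalty"},
}

def scrub_gen_conf(provider: str, gen_conf: dict) -> dict:
    banned = UNSUPPORTED_OPTIONS.get(provider, set())
    return {k: v for k, v in gen_conf.items() if k not in banned}
```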
```diff
@@ -148,3 +229,28 @@ class OllamaChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        options = {}
+        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                options=options
+            )
+            for resp in response:
+                if resp["done"]:
+                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
+                    return
+                ans += resp["message"]["content"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
```
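The `gen_conf`-to-Ollama translation above is essentially a rename table, with `max_tokens` becoming Ollama's `num_predict`. A table-driven equivalent (a sketch, not the PR's code):

```python
# Table-driven equivalent of the gen_conf -> Ollama options translation above
# (a sketch, not the PR's code). "max_tokens" renames to Ollama's "num_predict";
# the other keys pass through under the same name.
OLLAMA_OPTION_MAP = {
    "temperature": "temperature",
    "max_tokens": "num_predict",
    "top_p": "top_p",
    "presence_penalty": "presence_penalty",
    "frequency_penalty": "frequency_penalty",
}

def to_ollama_options(gen_conf: dict) -> dict:
    return {OLLAMA_OPTION_MAP[k]: v for k, v in gen_conf.items()
            if k in OLLAMA_OPTION_MAP}
```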
```diff
@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,

     if to_page > 0:
         if msg:
-            msg = f"Page({from_page+1}~{to_page+1}): " + msg
+            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
     d = {"progress_msg": msg}
     if prog is not None:
         d["progress"] = prog
```
```diff
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):
 def build(row):
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
-                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
+                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []

     callback = partial(
```
```diff
@@ -138,12 +138,12 @@ def build(row):
         bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
         binary = get_minio_binary(bucket, name)
         cron_logger.info(
-            "From minio({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
         cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                             to_page=row["to_page"], lang=row["language"], callback=callback,
                             kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
         cron_logger.info(
-            "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "Chunkking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except TimeoutError as e:
         callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
         cron_logger.error(
```
```diff
@@ -173,7 +173,7 @@ def build(row):
         d.update(ck)
         md5 = hashlib.md5()
         md5.update((ck["content_with_weight"] +
-                    str(d["doc_id"])).encode("utf-8"))
+                    str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
         d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
```
```diff
@@ -261,7 +261,7 @@ def main():

         st = timer()
         cks = build(r)
-        cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
         if cks is None:
             continue
         if not cks:
```
```diff
@@ -271,7 +271,7 @@ def main():
         ## set_progress(r["did"], -1, "ERROR: ")
         callback(
             msg="Finished slicing files(%d). Start to embedding the content." %
-            len(cks))
+            len(cks))
         st = timer()
         try:
             tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
```
```diff
@@ -279,19 +279,19 @@ def main():
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
             tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))

-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
+        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
         for b in range(0, len(cks), 32):
-            es_r = ELASTICSEARCH.bulk(cks[b:b+32], search.index_name(r["tenant_id"]))
+            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")

-        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
             callback(-1, "Index failure!")
             ELASTICSEARCH.deleteByQuery(
```
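The indexing loop above bulk-inserts chunks 32 at a time and maps the batch index onto the 0.8–0.9 progress band. The same pattern, isolated as a minimal sketch (`bulk` and `on_progress` stand in for the ELASTICSEARCH call and the task callback):

```python
# Minimal sketch of the batching pattern above: index chunks 32 at a time and
# report progress within the 0.8-0.9 band every 128 chunks, as the executor does.
def index_in_batches(chunks, bulk, on_progress, batch=32):
    for b in range(0, len(chunks), batch):
        bulk(chunks[b:b + batch])
        if b % 128 == 0:
            on_progress(0.8 + 0.1 * (b + 1) / len(chunks))
```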
```diff
@@ -307,8 +307,7 @@ def main():
             r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
         cron_logger.info(
             "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
-                r["id"], tk_count, len(cks), timer()-st))
-
+                r["id"], tk_count, len(cks), timer() - st))


 if __name__ == "__main__":
```
```diff
@@ -43,6 +43,9 @@ class ESConnection:
         v = v["number"].split(".")[0]
         return int(v) >= 7

+    def health(self):
+        return dict(self.es.cluster.health())
+
     def upsert(self, df, idxnm=""):
         res = []
         for d in df:
```
```diff
@@ -34,6 +34,16 @@ class RAGFlowMinio(object):
             del self.conn
         self.conn = None

+    def health(self):
+        bucket, fnm, binary = "_t@@@1", "_t@@@1", b"_t@@@1"
+        if not self.conn.bucket_exists(bucket):
+            self.conn.make_bucket(bucket)
+        r = self.conn.put_object(bucket, fnm,
+                                 BytesIO(binary),
+                                 len(binary)
+                                 )
+        return r
+
     def put(self, bucket, fnm, binary):
         for _ in range(3):
             try:
```
```diff
@@ -44,6 +44,10 @@ class RedisDB:
                 logging.warning("Redis can't be connected.")
         return self.REDIS

+    def health(self, queue_name):
+        self.REDIS.ping()
+        return self.REDIS.xinfo_groups(queue_name)[0]
+
     def is_alive(self):
         return self.REDIS is not None
```
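Beyond streaming, the PR adds `health()` probes to the Elasticsearch, MinIO, and Redis connectors: an ES cluster-health call, a MinIO round-trip write, and a Redis ping plus stream-group lookup. A hypothetical aggregation (not part of the PR) showing how the three probes could back a single status check; `ES`, `MINIO`, `REDIS`, and the queue name are illustrative stand-ins for the project's singletons:

```python
# Hypothetical status aggregation (not in the PR). ES, MINIO, and REDIS stand in
# for the connector singletons; the queue name is illustrative. A probe that
# raises marks its service as failing; one that returns marks it healthy.
def healthz():
    probes = {
        "elasticsearch": lambda: ES.health(),
        "minio": lambda: MINIO.health(),
        "redis": lambda: REDIS.health("task_queue"),
    }
    status = {}
    for name, probe in probes.items():
        try:
            probe()
            status[name] = "ok"
        except Exception as e:
            status[name] = "error: {}".format(e)
    return status
```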