add stream chat (#811)

### What problem does this PR solve?

#709 
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-05-16 20:14:53 +08:00
committed by GitHub
parent d6772f5dd7
commit 95f809187e
12 changed files with 353 additions and 102 deletions

View File

@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
from datetime import datetime, timedelta
from flask import request
from flask import request, Response
from flask_login import login_required, current_user
from api.db import FileType, ParserType
@ -31,11 +32,11 @@ from api.settings import RetCode
from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
from itsdangerous import URLSafeTimedSerializer
from api.db.services.task_service import TaskService, queue_tasks
from api.utils.file_utils import filename_type, thumbnail
from rag.utils.minio_conn import MINIO
from api.db.db_models import Task
from api.db.services.file2document_service import File2DocumentService
def generate_confirmation_token(tenent_id):
serializer = URLSafeTimedSerializer(tenent_id)
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@ -164,6 +165,7 @@ def completion():
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(retmsg="Conversation not found!")
if "quote" not in req: req["quote"] = False
msg = []
for m in req["messages"]:
@ -180,13 +182,45 @@ def completion():
return get_data_error_result(retmsg="Dialog not found!")
del req["conversation_id"]
del req["messages"]
ans = chat(dia, msg, **req)
if not conv.reference:
conv.reference = []
conv.reference.append(ans["reference"])
conv.message.append({"role": "assistant", "content": ans["answer"]})
API4ConversationService.append_message(conv.id, conv.to_dict())
return get_json_result(data=ans)
conv.message.append({"role": "assistant", "content": ""})
conv.reference.append({"chunks": [], "doc_aggs": []})
def fillin_conv(ans):
nonlocal conv
if not conv.reference:
conv.reference.append(ans["reference"])
else: conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
else:
ans = chat(dia, msg, False, **req)
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
return get_json_result(data=ans)
except Exception as e:
return server_error_response(e)
@ -229,7 +263,6 @@ def upload():
return get_json_result(
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
@ -253,7 +286,6 @@ def upload():
location += "_"
blob = request.files['file'].read()
MINIO.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
@ -266,42 +298,11 @@ def upload():
"size": len(blob),
"thumbnail": thumbnail(filename, blob)
}
form_data=request.form
if "parser_id" in form_data.keys():
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
doc["parser_id"] = request.form.get("parser_id").strip()
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
doc_result = DocumentService.insert(doc)
doc = DocumentService.insert(doc)
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
if "run" in form_data.keys():
if request.form.get("run").strip() == "1":
try:
info = {"run": 1, "progress": 0}
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
DocumentService.update_by_id(doc["id"], info)
# if str(req["run"]) == TaskStatus.CANCEL.value:
tenant_id = DocumentService.get_tenant_id(doc["id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
#e, doc = DocumentService.get_by_id(doc["id"])
TaskService.filter_delete([Task.doc_id == doc["id"]])
e, doc = DocumentService.get_by_id(doc["id"])
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name)
except Exception as e:
return server_error_response(e)
return get_json_result(data=doc_result.to_json())