mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix: update broken agent completion due to v0.20.0 changes (#9309)
### What problem does this PR solve? Update broken agent completion due to v0.20.0 changes. #9199 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -51,7 +51,7 @@ def create(tenant_id, chat_id):
|
|||||||
"name": req.get("name", "New session"),
|
"name": req.get("name", "New session"),
|
||||||
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
|
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
|
||||||
"user_id": req.get("user_id", ""),
|
"user_id": req.get("user_id", ""),
|
||||||
"reference":[{}],
|
"reference": [{}],
|
||||||
}
|
}
|
||||||
if not conv.get("name"):
|
if not conv.get("name"):
|
||||||
return get_error_data_result(message="`name` can not be empty.")
|
return get_error_data_result(message="`name` can not be empty.")
|
||||||
@ -475,41 +475,38 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
|||||||
@token_required
|
@token_required
|
||||||
def agent_completions(tenant_id, agent_id):
|
def agent_completions(tenant_id, agent_id):
|
||||||
req = request.json
|
req = request.json
|
||||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
|
||||||
if not cvs:
|
|
||||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
|
||||||
if req.get("session_id"):
|
|
||||||
dsl = cvs[0].dsl
|
|
||||||
if not isinstance(dsl, str):
|
|
||||||
dsl = json.dumps(dsl)
|
|
||||||
|
|
||||||
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
|
ans = {}
|
||||||
if not conv:
|
|
||||||
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
|
||||||
# If an update to UserCanvas is detected, update the API4Conversation.dsl
|
|
||||||
sync_dsl = req.get("sync_dsl", False)
|
|
||||||
if sync_dsl is True and cvs[0].update_time > conv[0].update_time:
|
|
||||||
current_dsl = conv[0].dsl
|
|
||||||
new_dsl = json.loads(dsl)
|
|
||||||
state_fields = ["history", "messages", "path", "reference"]
|
|
||||||
states = {field: current_dsl.get(field, []) for field in state_fields}
|
|
||||||
current_dsl.update(new_dsl)
|
|
||||||
current_dsl.update(states)
|
|
||||||
API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
|
|
||||||
else:
|
|
||||||
req["question"] = ""
|
|
||||||
if req.get("stream", True):
|
if req.get("stream", True):
|
||||||
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
|
|
||||||
|
def generate():
|
||||||
|
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||||
|
if isinstance(answer, str):
|
||||||
|
try:
|
||||||
|
ans = json.loads(answer[5:]) # remove "data:"
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ans.get("event") != "message":
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield answer
|
||||||
|
|
||||||
|
yield "data:[DONE]\n\n"
|
||||||
|
|
||||||
|
resp = Response(generate(), mimetype="text/event-stream")
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
resp.headers.add_header("Cache-control", "no-cache")
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
resp.headers.add_header("Connection", "keep-alive")
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
return resp
|
return resp
|
||||||
try:
|
|
||||||
for answer in agent_completion(tenant_id, agent_id, **req):
|
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||||
return get_result(data=answer)
|
try:
|
||||||
except Exception as e:
|
ans = json.loads(answer[5:]) # remove "data:"
|
||||||
return get_error_data_result(str(e))
|
except Exception as e:
|
||||||
|
return get_result(data=f"**ERROR**: {str(e)}")
|
||||||
|
return get_result(data=ans)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||||
@ -836,44 +833,30 @@ def chatbot_completions(dialog_id):
|
|||||||
|
|
||||||
|
|
||||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||||
def agent_bot_completions(agent_id):
|
@token_required
|
||||||
|
def agent_bot_completions(tenant_id, agent_id):
|
||||||
req = request.json
|
req = request.json
|
||||||
|
|
||||||
token = request.headers.get("Authorization").split()
|
|
||||||
if len(token) != 2:
|
|
||||||
return get_error_data_result(message='Authorization is not valid!"')
|
|
||||||
token = token[1]
|
|
||||||
objs = APIToken.query(beta=token)
|
|
||||||
if not objs:
|
|
||||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
|
||||||
|
|
||||||
if req.get("stream", True):
|
if req.get("stream", True):
|
||||||
resp = Response(agent_completion(objs[0].tenant_id, agent_id, **req), mimetype="text/event-stream")
|
resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
resp.headers.add_header("Cache-control", "no-cache")
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
resp.headers.add_header("Connection", "keep-alive")
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
|
for answer in agent_completion(tenant_id, agent_id, **req):
|
||||||
return get_result(data=answer)
|
return get_result(data=answer)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
||||||
def begin_inputs(agent_id):
|
@token_required
|
||||||
token = request.headers.get("Authorization").split()
|
def begin_inputs(tenant_id, agent_id):
|
||||||
if len(token) != 2:
|
|
||||||
return get_error_data_result(message='Authorization is not valid!"')
|
|
||||||
token = token[1]
|
|
||||||
objs = APIToken.query(beta=token)
|
|
||||||
if not objs:
|
|
||||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
|
||||||
|
|
||||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
||||||
|
|
||||||
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
|
canvas = Canvas(json.dumps(cvs.dsl), tenant_id)
|
||||||
return get_result(
|
return get_result(
|
||||||
data={
|
data={
|
||||||
"title": cvs.title,
|
"title": cvs.title,
|
||||||
|
|||||||
@ -123,7 +123,7 @@ class UserCanvasService(CommonService):
|
|||||||
|
|
||||||
|
|
||||||
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||||
query = kwargs.get("query", "")
|
query = kwargs.get("query", "") or kwargs.get("question", "")
|
||||||
files = kwargs.get("files", [])
|
files = kwargs.get("files", [])
|
||||||
inputs = kwargs.get("inputs", {})
|
inputs = kwargs.get("inputs", {})
|
||||||
user_id = kwargs.get("user_id", "")
|
user_id = kwargs.get("user_id", "")
|
||||||
|
|||||||
Reference in New Issue
Block a user