diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 0476e3572..a32d05889 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -51,7 +51,7 @@ def create(tenant_id, chat_id): "name": req.get("name", "New session"), "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}], "user_id": req.get("user_id", ""), - "reference":[{}], + "reference": [{}], } if not conv.get("name"): return get_error_data_result(message="`name` can not be empty.") @@ -475,41 +475,38 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): @token_required def agent_completions(tenant_id, agent_id): req = request.json - cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) - if not cvs: - return get_error_data_result(f"You don't own the agent {agent_id}") - if req.get("session_id"): - dsl = cvs[0].dsl - if not isinstance(dsl, str): - dsl = json.dumps(dsl) - conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id) - if not conv: - return get_error_data_result(f"You don't own the session {req['session_id']}") - # If an update to UserCanvas is detected, update the API4Conversation.dsl - sync_dsl = req.get("sync_dsl", False) - if sync_dsl is True and cvs[0].update_time > conv[0].update_time: - current_dsl = conv[0].dsl - new_dsl = json.loads(dsl) - state_fields = ["history", "messages", "path", "reference"] - states = {field: current_dsl.get(field, []) for field in state_fields} - current_dsl.update(new_dsl) - current_dsl.update(states) - API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl}) - else: - req["question"] = "" + ans = {} if req.get("stream", True): - resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream") + + def generate(): + for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + if isinstance(answer, str): + try: + ans = json.loads(answer[5:]) # remove "data:" + except Exception: + continue + + if ans.get("event") != "message": + continue + + yield answer + + yield "data:[DONE]\n\n" + + resp = Response(generate(), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - try: - for answer in agent_completion(tenant_id, agent_id, **req): - return get_result(data=answer) - except Exception as e: - return get_error_data_result(str(e)) + + for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + try: + ans = json.loads(answer[5:]) # remove "data:" + except Exception as e: + return get_result(data=f"**ERROR**: {str(e)}") + return get_result(data=ans) @manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 @@ -836,44 +833,30 @@ def chatbot_completions(dialog_id): @manager.route("/agentbots//completions", methods=["POST"]) # noqa: F821 -def agent_bot_completions(agent_id): +@token_required +def agent_bot_completions(tenant_id, agent_id): req = request.json - token = request.headers.get("Authorization").split() - if len(token) != 2: - return get_error_data_result(message='Authorization is not valid!"') - token = token[1] - objs = APIToken.query(beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - if req.get("stream", True): - resp = Response(agent_completion(objs[0].tenant_id, agent_id, **req), mimetype="text/event-stream") + resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - for answer in agent_completion(objs[0].tenant_id, agent_id, **req): + for answer in agent_completion(tenant_id, agent_id, **req): return get_result(data=answer) @manager.route("/agentbots//inputs", methods=["GET"]) # noqa: F821 -def begin_inputs(agent_id): - token = request.headers.get("Authorization").split() - if len(token) != 2: - return get_error_data_result(message='Authorization is not valid!"') - token = token[1] - objs = APIToken.query(beta=token) - if not objs: - return get_error_data_result(message='Authentication error: API key is invalid!"') - +@token_required +def begin_inputs(tenant_id, agent_id): e, cvs = UserCanvasService.get_by_id(agent_id) if not e: return get_error_data_result(f"Can't find agent by ID: {agent_id}") - canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id) + canvas = Canvas(json.dumps(cvs.dsl), tenant_id) return get_result( data={ "title": cvs.title, diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index f791edf7a..b15c12007 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -123,7 +123,7 @@ class UserCanvasService(CommonService): def completion(tenant_id, agent_id, session_id=None, **kwargs): - query = kwargs.get("query", "") + query = kwargs.get("query", "") or kwargs.get("question", "") files = kwargs.get("files", []) inputs = kwargs.get("inputs", {}) user_id = kwargs.get("user_id", "")