mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add OpenAI compatible API for agent (#6329)
### What problem does this PR solve? add openai agent _Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._ ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@ -17,22 +17,23 @@ import json
|
||||
import re
|
||||
import time
|
||||
|
||||
import tiktoken
|
||||
from flask import Response, jsonify, request
|
||||
|
||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||
from api.db.services.conversation_service import completion as rag_completion
|
||||
from api.db.services.canvas_service import completion as agent_completion, completionOpenAI
|
||||
from agent.canvas import Canvas
|
||||
from api.db import LLMType, StatusEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||
from api.db.services.conversation_service import completion as rag_completion
|
||||
from api.db.services.dialog_service import DialogService, ask, chat
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_error_data_result, get_result, token_required, validate_request
|
||||
from api.utils.api_utils import get_result, token_required, get_data_openai, get_error_data_result, validate_request
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@ -71,14 +72,11 @@ def create_agent_session(tenant_id, agent_id):
|
||||
req = request.form
|
||||
files = request.files
|
||||
user_id = request.args.get("user_id", "")
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not e:
|
||||
return get_error_data_result("Agent not found.")
|
||||
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result("You cannot access the agent.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
@ -352,6 +350,40 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
}
|
||||
return jsonify(response)
|
||||
|
||||
@manager.route('/agents_openai/<agent_id>/chat/completions', methods=['POST']) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def agents_completion_openai_compatibility (tenant_id, agent_id):
|
||||
req = request.json
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
messages = req.get("messages", [])
|
||||
if not messages:
|
||||
return get_error_data_result("You must provide at least one message.")
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
|
||||
prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
|
||||
if not filtered_messages:
|
||||
return jsonify(get_data_openai(
|
||||
id=agent_id,
|
||||
content="No valid messages found (user or assistant).",
|
||||
finish_reason="stop",
|
||||
model=req.get("model", ""),
|
||||
completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
|
||||
prompt_tokens=prompt_tokens,
|
||||
))
|
||||
|
||||
# Get the last user message as the question
|
||||
question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
|
||||
|
||||
if req.get("stream", True):
|
||||
return Response(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", ""), stream=True), mimetype="text/event-stream")
|
||||
else:
|
||||
# For non-streaming, just return the response directly
|
||||
response = next(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", ""), stream=False))
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
@ -364,9 +396,7 @@ def agent_completions(tenant_id, agent_id):
|
||||
dsl = cvs[0].dsl
|
||||
if not isinstance(dsl, str):
|
||||
dsl = json.dumps(dsl)
|
||||
# canvas = Canvas(dsl, tenant_id)
|
||||
# if canvas.get_preset_param():
|
||||
# req["question"] = ""
|
||||
|
||||
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
|
||||
if not conv:
|
||||
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
||||
|
||||
Reference in New Issue
Block a user