mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: OpenAI-compatible-API supports references (#8997)
### What problem does this PR solve? OpenAI-compatible-API supports references. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -19,21 +19,22 @@ import time
|
||||
|
||||
import tiktoken
|
||||
from flask import Response, jsonify, request
|
||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||
from api.db.services.conversation_service import completion as rag_completion
|
||||
from api.db.services.canvas_service import completion as agent_completion, completionOpenAI
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db import LLMType, StatusEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.canvas_service import UserCanvasService, completionOpenAI
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||
from api.db.services.conversation_service import completion as rag_completion
|
||||
from api.db.services.dialog_service import DialogService, ask, chat
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_result, token_required, get_data_openai, get_error_data_result, validate_request, check_duplicate_ids
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_result, token_required, validate_request
|
||||
from rag.prompts import chunks_format
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@ -181,10 +182,16 @@ def chat_completion(tenant_id, chat_id):
|
||||
def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"""
|
||||
OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
|
||||
|
||||
|
||||
This function allows users to interact with a model and receive responses based on a series of historical messages.
|
||||
If `stream` is set to True (by default), the response will be streamed in chunks, mimicking the OpenAI-style API.
|
||||
Set `stream` to False explicitly, the response will be returned in a single complete answer.
|
||||
|
||||
Reference:
|
||||
|
||||
- If `stream` is True, the final answer and reference information will appear in the **last chunk** of the stream.
|
||||
- If `stream` is False, the reference will be included in `choices[0].message.reference`.
|
||||
|
||||
Example usage:
|
||||
|
||||
curl -X POST https://ragflow_address.com/api/v1/chats_openai/<chat_id>/chat/completions \
|
||||
@ -202,7 +209,10 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
|
||||
model = "model"
|
||||
client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/<chat_id>")
|
||||
|
||||
|
||||
stream = True
|
||||
reference = True
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
@ -211,17 +221,24 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
{"role": "assistant", "content": "I am an AI assistant named..."},
|
||||
{"role": "user", "content": "Can you tell me how to install neovim"},
|
||||
],
|
||||
stream=True
|
||||
stream=stream,
|
||||
extra_body={"reference": reference}
|
||||
)
|
||||
|
||||
stream = True
|
||||
|
||||
if stream:
|
||||
for chunk in completion:
|
||||
print(chunk)
|
||||
for chunk in completion:
|
||||
print(chunk)
|
||||
if reference and chunk.choices[0].finish_reason == "stop":
|
||||
print(f"Reference:\n{chunk.choices[0].delta.reference}")
|
||||
print(f"Final content:\n{chunk.choices[0].delta.final_content}")
|
||||
else:
|
||||
print(completion.choices[0].message.content)
|
||||
if reference:
|
||||
print(completion.choices[0].message.reference)
|
||||
"""
|
||||
req = request.json
|
||||
req = request.get_json()
|
||||
|
||||
need_reference = bool(req.get("reference", False))
|
||||
|
||||
messages = req.get("messages", [])
|
||||
# To prevent empty [] input
|
||||
@ -261,9 +278,23 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
token_used = 0
|
||||
answer_cache = ""
|
||||
reasoning_cache = ""
|
||||
last_ans = {}
|
||||
response = {
|
||||
"id": f"chatcmpl-{chat_id}",
|
||||
"choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "",
|
||||
"role": "assistant",
|
||||
"function_call": None,
|
||||
"tool_calls": None,
|
||||
"reasoning_content": "",
|
||||
},
|
||||
"finish_reason": None,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"created": int(time.time()),
|
||||
"model": "model",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -272,7 +303,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
}
|
||||
|
||||
try:
|
||||
for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools):
|
||||
for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
last_ans = ans
|
||||
answer = ans["answer"]
|
||||
|
||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||
@ -324,6 +356,9 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
if need_reference:
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
@ -335,7 +370,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return resp
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools):
|
||||
for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
# focus answer content only
|
||||
answer = ans
|
||||
break
|
||||
@ -356,14 +391,28 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"rejected_prediction_tokens": 0, # 0 for simplicity
|
||||
},
|
||||
},
|
||||
"choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
}
|
||||
if need_reference:
|
||||
response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", []))
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
@manager.route('/agents_openai/<agent_id>/chat/completions', methods=['POST']) # noqa: F821
|
||||
|
||||
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def agents_completion_openai_compatibility (tenant_id, agent_id):
|
||||
def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = request.json
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
messages = req.get("messages", [])
|
||||
@ -371,29 +420,31 @@ def agents_completion_openai_compatibility (tenant_id, agent_id):
|
||||
return get_error_data_result("You must provide at least one message.")
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
|
||||
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
|
||||
prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
|
||||
if not filtered_messages:
|
||||
return jsonify(get_data_openai(
|
||||
id=agent_id,
|
||||
content="No valid messages found (user or assistant).",
|
||||
finish_reason="stop",
|
||||
model=req.get("model", ""),
|
||||
completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
|
||||
prompt_tokens=prompt_tokens,
|
||||
))
|
||||
|
||||
return jsonify(
|
||||
get_data_openai(
|
||||
id=agent_id,
|
||||
content="No valid messages found (user or assistant).",
|
||||
finish_reason="stop",
|
||||
model=req.get("model", ""),
|
||||
completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
|
||||
prompt_tokens=prompt_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
# Get the last user message as the question
|
||||
question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
|
||||
|
||||
|
||||
if req.get("stream", True):
|
||||
return Response(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id","")), stream=True), mimetype="text/event-stream")
|
||||
return Response(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id", "")), stream=True), mimetype="text/event-stream")
|
||||
else:
|
||||
# For non-streaming, just return the response directly
|
||||
response = next(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id","")), stream=False))
|
||||
response = next(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id", "")), stream=False))
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
@ -547,7 +598,7 @@ def list_agent_session(tenant_id, agent_id):
|
||||
def delete(tenant_id, chat_id):
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You don't own the chat")
|
||||
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
@ -563,10 +614,10 @@ def delete(tenant_id, chat_id):
|
||||
conv_list.append(conv.id)
|
||||
else:
|
||||
conv_list = ids
|
||||
|
||||
|
||||
unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session")
|
||||
conv_list = unique_conv_ids
|
||||
|
||||
|
||||
for id in conv_list:
|
||||
conv = ConversationService.query(id=id, dialog_id=chat_id)
|
||||
if not conv:
|
||||
@ -574,25 +625,19 @@ def delete(tenant_id, chat_id):
|
||||
continue
|
||||
ConversationService.delete_by_id(id)
|
||||
success_count += 1
|
||||
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} sessions with {len(errors)} errors"
|
||||
)
|
||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages}
|
||||
)
|
||||
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
|
||||
return get_result()
|
||||
|
||||
|
||||
@ -632,25 +677,19 @@ def delete_agent_session(tenant_id, agent_id):
|
||||
continue
|
||||
API4ConversationService.delete_by_id(session_id)
|
||||
success_count += 1
|
||||
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} sessions with {len(errors)} errors"
|
||||
)
|
||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages}
|
||||
)
|
||||
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
|
||||
return get_result()
|
||||
|
||||
|
||||
@ -719,7 +758,7 @@ Related search terms:
|
||||
|
||||
Reason:
|
||||
- When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
|
||||
- Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
|
||||
- Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
|
||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||
|
||||
"""
|
||||
|
||||
@ -69,21 +69,31 @@ from openai import OpenAI
|
||||
model = "model"
|
||||
client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/<chat_id>")
|
||||
|
||||
stream = True
|
||||
reference = True
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
{"role": "assistant", "content": "I am an AI assistant named..."},
|
||||
{"role": "user", "content": "Can you tell me how to install neovim"},
|
||||
],
|
||||
stream=True
|
||||
stream=stream,
|
||||
extra_body={"reference": reference}
|
||||
)
|
||||
|
||||
stream = True
|
||||
if stream:
|
||||
for chunk in completion:
|
||||
print(chunk)
|
||||
for chunk in completion:
|
||||
print(chunk)
|
||||
if reference and chunk.choices[0].finish_reason == "stop":
|
||||
print(f"Reference:\n{chunk.choices[0].delta.reference}")
|
||||
print(f"Final content:\n{chunk.choices[0].delta.final_content}")
|
||||
else:
|
||||
print(completion.choices[0].message.content)
|
||||
if reference:
|
||||
print(completion.choices[0].message.reference)
|
||||
```
|
||||
|
||||
## DATASET MANAGEMENT
|
||||
|
||||
Reference in New Issue
Block a user