Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Feat: OpenAI-compatible-API supports references (#8997)
### What problem does this PR solve?

The OpenAI-compatible API now supports returning retrieval references alongside answers.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
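In practice a client opts in through `extra_body`. A minimal non-streaming sketch (the server address, chat ID, and API key are placeholders; `reference` is the flag this PR adds):

```python
from openai import OpenAI

# Placeholders: substitute your RAGFlow address, chat ID, and API key.
client = OpenAI(
    api_key="ragflow-api-key",
    base_url="http://ragflow_address/api/v1/chats_openai/<chat_id>",
)

completion = client.chat.completions.create(
    model="model",  # placeholder name, as in the docstring example below
    messages=[{"role": "user", "content": "Can you tell me how to install neovim"}],
    stream=False,
    extra_body={"reference": True},  # new in this PR: request retrieval references
)

print(completion.choices[0].message.content)
# With reference=True, the formatted chunks ride along on the message:
print(completion.choices[0].message.reference)
```

The diff below threads this flag through `need_reference` into `chat(..., quote=need_reference)` and formats the retrieved chunks with `chunks_format` before attaching them to the response.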
@@ -19,21 +19,22 @@ import time
 import tiktoken
 from flask import Response, jsonify, request
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.canvas_service import completion as agent_completion, completionOpenAI
 from agent.canvas import Canvas
 from api.db import LLMType, StatusEnum
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
-from api.db.services.canvas_service import UserCanvasService
+from api.db.services.canvas_service import UserCanvasService, completionOpenAI
+from api.db.services.canvas_service import completion as agent_completion
+from api.db.services.conversation_service import ConversationService, iframe_completion
+from api.db.services.conversation_service import completion as rag_completion
 from api.db.services.dialog_service import DialogService, ask, chat
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils import get_uuid
-from api.utils.api_utils import get_result, token_required, get_data_openai, get_error_data_result, validate_request, check_duplicate_ids
 from api.db.services.llm_service import LLMBundle
+from api.utils import get_uuid
+from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_result, token_required, validate_request
+from rag.prompts import chunks_format


 @manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
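A note on the new `from rag.prompts import chunks_format` import: it converts internal retrieval chunks into the public `reference` payload used later in this diff. Its body isn't shown here, so the following is only an illustrative sketch of that kind of normalization; every field name in it is an assumption, not the confirmed schema:

```python
# Hypothetical sketch only -- the real chunks_format lives in rag/prompts.py,
# and its actual output fields are not shown in this diff.
def chunks_format_sketch(reference: dict) -> list[dict]:
    return [
        {
            # Map internal keys (assumed names) to stable public ones.
            "id": chunk.get("chunk_id") or chunk.get("id"),
            "content": chunk.get("content"),
            "document_name": chunk.get("docnm_kwd") or chunk.get("document_name"),
            "similarity": chunk.get("similarity"),
        }
        for chunk in reference.get("chunks", [])
    ]
```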
@@ -185,6 +186,12 @@ def chat_completion_openai_like(tenant_id, chat_id):
     This function allows users to interact with a model and receive responses based on a series of historical messages.
     If `stream` is set to True (by default), the response will be streamed in chunks, mimicking the OpenAI-style API.
     Set `stream` to False explicitly, the response will be returned in a single complete answer.
+
+    Reference:
+
+    - If `stream` is True, the final answer and reference information will appear in the **last chunk** of the stream.
+    - If `stream` is False, the reference will be included in `choices[0].message.reference`.
+
     Example usage:

     curl -X POST https://ragflow_address.com/api/v1/chats_openai/<chat_id>/chat/completions \
@@ -203,6 +210,9 @@ def chat_completion_openai_like(tenant_id, chat_id):
     model = "model"
     client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/<chat_id>")
+
+    stream = True
+    reference = True
     completion = client.chat.completions.create(
         model=model,
         messages=[
@@ -211,17 +221,24 @@ def chat_completion_openai_like(tenant_id, chat_id):
             {"role": "assistant", "content": "I am an AI assistant named..."},
             {"role": "user", "content": "Can you tell me how to install neovim"},
         ],
-        stream=True
+        stream=stream,
+        extra_body={"reference": reference}
     )

-    stream = True
     if stream:
         for chunk in completion:
             print(chunk)
+            if reference and chunk.choices[0].finish_reason == "stop":
+                print(f"Reference:\n{chunk.choices[0].delta.reference}")
+                print(f"Final content:\n{chunk.choices[0].delta.final_content}")
     else:
         print(completion.choices[0].message.content)
+        if reference:
+            print(completion.choices[0].message.reference)
     """
-    req = request.json
+    req = request.get_json()
+
+    need_reference = bool(req.get("reference", False))
     messages = req.get("messages", [])
     # To prevent empty [] input
@@ -261,9 +278,23 @@ def chat_completion_openai_like(tenant_id, chat_id):
         token_used = 0
         answer_cache = ""
         reasoning_cache = ""
+        last_ans = {}
         response = {
             "id": f"chatcmpl-{chat_id}",
-            "choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
+            "choices": [
+                {
+                    "delta": {
+                        "content": "",
+                        "role": "assistant",
+                        "function_call": None,
+                        "tool_calls": None,
+                        "reasoning_content": "",
+                    },
+                    "finish_reason": None,
+                    "index": 0,
+                    "logprobs": None,
+                }
+            ],
             "created": int(time.time()),
             "model": "model",
             "object": "chat.completion.chunk",
@@ -272,7 +303,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
         }

         try:
-            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools):
+            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+                last_ans = ans
                 answer = ans["answer"]

                 reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
@@ -324,6 +356,9 @@ def chat_completion_openai_like(tenant_id, chat_id):
             response["choices"][0]["delta"]["reasoning_content"] = None
             response["choices"][0]["finish_reason"] = "stop"
             response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
+            if need_reference:
+                response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
+                response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
             yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
             yield "data:[DONE]\n\n"

@@ -335,7 +370,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
             return resp
     else:
         answer = None
-        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools):
+        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
             # focus answer content only
             answer = ans
             break
@@ -356,11 +391,25 @@ def chat_completion_openai_like(tenant_id, chat_id):
                     "rejected_prediction_tokens": 0,  # 0 for simplicity
                 },
             },
-            "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": content,
+                    },
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                    "index": 0,
+                }
+            ],
         }
+        if need_reference:
+            response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", []))

         return jsonify(response)


-@manager.route('/agents_openai/<agent_id>/chat/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"])  # noqa: F821
 @validate_request("model", "messages")  # noqa: F821
 @token_required
 def agents_completion_openai_compatibility(tenant_id, agent_id):
@@ -375,14 +424,16 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
     filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
     prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
     if not filtered_messages:
-        return jsonify(get_data_openai(
+        return jsonify(
+            get_data_openai(
                 id=agent_id,
                 content="No valid messages found (user or assistant).",
                 finish_reason="stop",
                 model=req.get("model", ""),
                 completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
                 prompt_tokens=prompt_tokens,
-        ))
+            )
+        )

     # Get the last user message as the question
     question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
@@ -577,19 +628,13 @@ def delete(tenant_id, chat_id):

     if errors:
         if success_count > 0:
-            return get_result(
-                data={"success_count": success_count, "errors": errors},
-                message=f"Partially deleted {success_count} sessions with {len(errors)} errors"
-            )
+            return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
         else:
             return get_error_data_result(message="; ".join(errors))

     if duplicate_messages:
         if success_count > 0:
-            return get_result(
-                message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
-                data={"success_count": success_count, "errors": duplicate_messages}
-            )
+            return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
         else:
             return get_error_data_result(message=";".join(duplicate_messages))

@@ -635,19 +680,13 @@ def delete_agent_session(tenant_id, agent_id):

     if errors:
         if success_count > 0:
-            return get_result(
-                data={"success_count": success_count, "errors": errors},
-                message=f"Partially deleted {success_count} sessions with {len(errors)} errors"
-            )
+            return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
         else:
             return get_error_data_result(message="; ".join(errors))

     if duplicate_messages:
         if success_count > 0:
-            return get_result(
-                message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
-                data={"success_count": success_count, "errors": duplicate_messages}
-            )
+            return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
         else:
             return get_error_data_result(message=";".join(duplicate_messages))

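For readers following the streaming branch above: the server emits each chunk as an SSE `data:` line and finishes with `data:[DONE]`; when `reference` is requested, the final chunk (the one with `finish_reason == "stop"`) also carries `delta.reference` and `delta.final_content`. A minimal raw-SSE consumer sketch (placeholder address and key; using `requests` is an illustrative choice, not something the PR prescribes):

```python
import json

import requests  # assumption: any HTTP client that can stream responses works

url = "http://ragflow_address/api/v1/chats_openai/<chat_id>/chat/completions"  # placeholder
headers = {"Authorization": "Bearer ragflow-api-key"}  # placeholder key
payload = {
    "model": "model",
    "messages": [{"role": "user", "content": "Can you tell me how to install neovim"}],
    "stream": True,
    "reference": True,  # the flag need_reference reads server-side
}

with requests.post(url, json=payload, headers=headers, stream=True) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        # The server writes "data:{json}" lines and ends with "data:[DONE]".
        if not raw or not raw.startswith("data:"):
            continue
        data = raw[len("data:"):].strip()
        if data == "[DONE]":
            break
        choice = json.loads(data)["choices"][0]
        if choice.get("finish_reason") == "stop":
            # Present only on the final chunk when reference=True.
            print("reference:", choice["delta"].get("reference"))
            print("final_content:", choice["delta"].get("final_content"))
        elif choice["delta"].get("content"):
            print(choice["delta"]["content"], end="", flush=True)
```

The companion change below updates the API reference documentation to match.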
@@ -69,21 +69,31 @@ from openai import OpenAI
 model = "model"
 client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/<chat_id>")
+
+stream = True
+reference = True
 completion = client.chat.completions.create(
     model=model,
     messages=[
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Who are you?"},
+        {"role": "assistant", "content": "I am an AI assistant named..."},
+        {"role": "user", "content": "Can you tell me how to install neovim"},
     ],
-    stream=True
+    stream=stream,
+    extra_body={"reference": reference}
 )

-stream = True
 if stream:
     for chunk in completion:
         print(chunk)
+        if reference and chunk.choices[0].finish_reason == "stop":
+            print(f"Reference:\n{chunk.choices[0].delta.reference}")
+            print(f"Final content:\n{chunk.choices[0].delta.final_content}")
 else:
     print(completion.choices[0].message.content)
+    if reference:
+        print(completion.choices[0].message.reference)
 ```

 ## DATASET MANAGEMENT