diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 0468381d3..3b10ff3d8 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -18,13 +18,14 @@
 import copy
 import re
 import time
-import tiktoken
 import os
 import tempfile
 import logging
 
 from quart import Response, jsonify, request
 
+from common.token_utils import num_tokens_from_string
+
 from agent.canvas import Canvas
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
@@ -265,7 +266,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
     prompt = messages[-1]["content"]
 
     # Treat context tokens as reasoning tokens
-    context_token_used = sum(len(message["content"]) for message in messages)
+    context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages)
 
     dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
     if not dia:
@@ -358,7 +359,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
                 delta = ans.get("answer") or ""
                 if not delta:
                     continue
-                token_used += len(delta)
+                token_used += num_tokens_from_string(delta)
                 if in_think:
                     full_reasoning += delta
                     response["choices"][0]["delta"]["reasoning_content"] = delta
@@ -376,7 +377,8 @@ async def chat_completion_openai_like(tenant_id, chat_id):
             response["choices"][0]["delta"]["content"] = None
             response["choices"][0]["delta"]["reasoning_content"] = None
             response["choices"][0]["finish_reason"] = "stop"
-            response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
+            prompt_tokens = num_tokens_from_string(prompt)
+            response["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": token_used, "total_tokens": prompt_tokens + token_used}
             if need_reference:
                 reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
                 response["choices"][0]["delta"]["reference"] = chunks_format(reference_payload)
@@ -407,12 +409,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
             "created": int(time.time()),
             "model": req.get("model", ""),
             "usage": {
-                "prompt_tokens": len(prompt),
-                "completion_tokens": len(content),
-                "total_tokens": len(prompt) + len(content),
+                "prompt_tokens": num_tokens_from_string(prompt),
+                "completion_tokens": num_tokens_from_string(content),
+                "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content),
                 "completion_tokens_details": {
                     "reasoning_tokens": context_token_used,
-                    "accepted_prediction_tokens": len(content),
+                    "accepted_prediction_tokens": num_tokens_from_string(content),
                     "rejected_prediction_tokens": 0,  # 0 for simplicity
                 },
             },
@@ -439,7 +441,6 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 @token_required
 async def agents_completion_openai_compatibility(tenant_id, agent_id):
     req = await get_request_json()
-    tiktoken_encode = tiktoken.get_encoding("cl100k_base")
     messages = req.get("messages", [])
     if not messages:
         return get_error_data_result("You must provide at least one message.")
@@ -447,7 +448,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
         return get_error_data_result(f"You don't own the agent {agent_id}")
 
     filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
-    prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages)
+    prompt_tokens = sum(num_tokens_from_string(m["content"]) for m in filtered_messages)
     if not filtered_messages:
         return jsonify(
             get_data_openai(
@@ -455,7 +456,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
                 content="No valid messages found (user or assistant).",
                 finish_reason="stop",
                 model=req.get("model", ""),
-                completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")),
+                completion_tokens=num_tokens_from_string("No valid messages found (user or assistant)."),
                 prompt_tokens=prompt_tokens,
             )
         )
diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py
index 7e1d9927a..c1567f574 100644
--- a/test/testcases/test_http_api/common.py
+++ b/test/testcases/test_http_api/common.py
@@ -374,3 +374,22 @@ def chat_completions(auth, chat_id, payload=None):
     url = f"{HOST_ADDRESS}/api/{VERSION}/chats/{chat_id}/completions"
     res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
     return res.json()
+
+
+def chat_completions_openai(auth, chat_id, payload=None):
+    """
+    Send a request to the OpenAI-compatible chat completions endpoint.
+
+    Args:
+        auth: Authentication object
+        chat_id: Chat assistant ID
+        payload: Dictionary in OpenAI chat completions format containing:
+            - messages: list (required) - List of message objects with 'role' and 'content'
+            - stream: bool (optional) - Whether to stream responses, default False
+
+    Returns:
+        Response JSON in OpenAI chat completions format with usage information
+    """
+    url = f"{HOST_ADDRESS}/api/{VERSION}/chats_openai/{chat_id}/chat/completions"
+    res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
+    return res.json()
diff --git a/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py
new file mode 100644
index 000000000..e126119ad
--- /dev/null
+++ b/test/testcases/test_http_api/test_session_management/test_chat_completions_openai.py
@@ -0,0 +1,132 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+import pytest
+from common import (
+    bulk_upload_documents,
+    chat_completions_openai,
+    create_chat_assistant,
+    delete_chat_assistants,
+    list_documents,
+    parse_documents,
+)
+from utils import wait_for
+
+
+@wait_for(200, 1, "Document parsing timeout")
+def _parse_done(auth, dataset_id, document_ids=None):
+    res = list_documents(auth, dataset_id)
+    target_docs = res["data"]["docs"]
+    if document_ids is None:
+        return all(doc.get("run") == "DONE" for doc in target_docs)
+    target_ids = set(document_ids)
+    for doc in target_docs:
+        if doc.get("id") in target_ids and doc.get("run") != "DONE":
+            return False
+    return True
+
+
+class TestChatCompletionsOpenAI:
+    """Test cases for the OpenAI-compatible chat completions endpoint"""
+
+    @pytest.mark.p2
+    def test_openai_chat_completion_non_stream(self, HttpApiAuth, add_dataset_func, tmp_path, request):
+        """Test that the OpenAI-compatible endpoint returns a proper response with token usage"""
+        dataset_id = add_dataset_func
+        document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path)
+        res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids})
+        assert res["code"] == 0, res
+        _parse_done(HttpApiAuth, dataset_id, document_ids)
+
+        res = create_chat_assistant(HttpApiAuth, {"name": "openai_endpoint_test", "dataset_ids": [dataset_id]})
+        assert res["code"] == 0, res
+        chat_id = res["data"]["id"]
+        request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth))
+
+        res = chat_completions_openai(
+            HttpApiAuth,
+            chat_id,
+            {
+                "model": "model",  # Required by the OpenAI-compatible API; the value is ignored by RAGFlow
+                "messages": [{"role": "user", "content": "hello"}],
+                "stream": False,
+            },
+        )
+
+        # Verify the OpenAI-compatible response structure
+        assert "choices" in res, f"Response should contain 'choices': {res}"
+        assert len(res["choices"]) > 0, f"'choices' should not be empty: {res}"
+        assert "message" in res["choices"][0], f"Choice should contain 'message': {res}"
+        assert "content" in res["choices"][0]["message"], f"Message should contain 'content': {res}"
+
+        # Verify token usage is present and uses actual token counts (not character counts)
+        assert "usage" in res, f"Response should contain 'usage': {res}"
+        usage = res["usage"]
+        assert "prompt_tokens" in usage, f"'usage' should contain 'prompt_tokens': {usage}"
+        assert "completion_tokens" in usage, f"'usage' should contain 'completion_tokens': {usage}"
+        assert "total_tokens" in usage, f"'usage' should contain 'total_tokens': {usage}"
+        assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], \
+            f"total_tokens should equal prompt_tokens + completion_tokens: {usage}"
+
+    @pytest.mark.p2
+    def test_openai_chat_completion_token_count_reasonable(self, HttpApiAuth, add_dataset_func, tmp_path, request):
+        """Test that token counts are reasonable (tiktoken-based, not character counts)"""
+        dataset_id = add_dataset_func
+        document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path)
+        res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids})
+        assert res["code"] == 0, res
+        _parse_done(HttpApiAuth, dataset_id, document_ids)
+
+        res = create_chat_assistant(HttpApiAuth, {"name": "openai_token_count_test", "dataset_ids": [dataset_id]})
+        assert res["code"] == 0, res
+        chat_id = res["data"]["id"]
+        request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth))
+
+        # Use a message with a known token count:
+        # "hello" is 1 token in the cl100k_base encoding
+        res = chat_completions_openai(
+            HttpApiAuth,
+            chat_id,
+            {
+                "model": "model",  # Required by the OpenAI-compatible API; the value is ignored by RAGFlow
+                "messages": [{"role": "user", "content": "hello"}],
+                "stream": False,
+            },
+        )
+
+        assert "usage" in res, f"Response should contain 'usage': {res}"
+        usage = res["usage"]
+
+        # The prompt token count should be plausible for the message "hello" plus any system context.
+        # If len() were used instead of tiktoken, the value would be a character count, which
+        # overstates real token usage ("hello" is 5 characters but only 1 token in cl100k_base).
+        # With tiktoken-based counting, prompt_tokens covers that 1 token plus the context.
+        assert usage["prompt_tokens"] > 0, f"prompt_tokens should be greater than 0: {usage}"
+        assert usage["completion_tokens"] > 0, f"completion_tokens should be greater than 0: {usage}"
+
+    @pytest.mark.p2
+    def test_openai_chat_completion_invalid_chat(self, HttpApiAuth):
+        """Test that the OpenAI-compatible endpoint returns an error for an invalid chat ID"""
+        res = chat_completions_openai(
+            HttpApiAuth,
+            "invalid_chat_id",
+            {
+                "model": "model",  # Required by the OpenAI-compatible API; the value is ignored by RAGFlow
+                "messages": [{"role": "user", "content": "hello"}],
+                "stream": False,
+            },
+        )
+        # Should return an error (format may vary based on implementation)
+        assert "error" in res or res.get("code") != 0, f"Should return error for invalid chat: {res}"