diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index dba320d98..6810ca647 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -28,6 +28,8 @@ CHUNK_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents/{{document_id CHAT_ASSISTANT_API_URL = f"/api/{VERSION}/chats" SESSION_WITH_CHAT_ASSISTANT_API_URL = f"/api/{VERSION}/chats/{{chat_id}}/sessions" SESSION_WITH_AGENT_API_URL = f"/api/{VERSION}/agents/{{agent_id}}/sessions" +AGENT_API_URL = f"/api/{VERSION}/agents" +RETRIEVAL_API_URL = f"/api/{VERSION}/retrieval" # DATASET MANAGEMENT @@ -170,7 +172,7 @@ def delete_chunks(auth, dataset_id, document_id, payload=None): def retrieval_chunks(auth, payload=None): - url = f"{HOST_ADDRESS}/api/v1/retrieval" + url = f"{HOST_ADDRESS}{RETRIEVAL_API_URL}" res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() @@ -237,6 +239,8 @@ def update_session_with_chat_assistant(auth, chat_assistant_id, session_id, payl def delete_session_with_chat_assistants(auth, chat_assistant_id, payload=None): url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id) + if payload is None: + payload = {} res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() @@ -247,3 +251,107 @@ def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num): res = create_session_with_chat_assistant(auth, chat_assistant_id, {"name": f"session_with_chat_assistant_{i}"}) session_ids.append(res["data"]["id"]) return session_ids + + +# DATASET GRAPH AND TASKS +def knowledge_graph(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/knowledge_graph" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def delete_knowledge_graph(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/knowledge_graph" + if payload is None: + res = requests.delete(url=url, headers=HEADERS, auth=auth) + else: + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def run_graphrag(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/run_graphrag" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def trace_graphrag(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/trace_graphrag" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def run_raptor(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/run_raptor" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def trace_raptor(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/trace_raptor" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def metadata_summary(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}/metadata/summary" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +# CHAT COMPLETIONS AND RELATED QUESTIONS +def chat_completions(auth, chat_assistant_id, payload=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}/completions" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def related_questions(auth, payload=None): + url = f"{HOST_ADDRESS}/api/{VERSION}/sessions/related_questions" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +# AGENT MANAGEMENT AND SESSIONS +def create_agent(auth, payload=None): + url = f"{HOST_ADDRESS}{AGENT_API_URL}" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def list_agents(auth, params=None): + url = f"{HOST_ADDRESS}{AGENT_API_URL}" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def delete_agent(auth, agent_id): + url = f"{HOST_ADDRESS}{AGENT_API_URL}/{agent_id}" + res = requests.delete(url=url, headers=HEADERS, auth=auth) + return res.json() + + +def create_agent_session(auth, agent_id, payload=None, params=None): + url = f"{HOST_ADDRESS}{SESSION_WITH_AGENT_API_URL}".format(agent_id=agent_id) + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload, params=params) + return res.json() + + +def list_agent_sessions(auth, agent_id, params=None): + url = f"{HOST_ADDRESS}{SESSION_WITH_AGENT_API_URL}".format(agent_id=agent_id) + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def delete_agent_sessions(auth, agent_id, payload=None): + url = f"{HOST_ADDRESS}{SESSION_WITH_AGENT_API_URL}".format(agent_id=agent_id) + if payload is None: + payload = {} + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def agent_completions(auth, agent_id, payload=None): + url = f"{HOST_ADDRESS}{AGENT_API_URL}/{agent_id}/completions" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() diff --git a/test/testcases/test_http_api/test_dataset_management/test_graphrag_tasks.py b/test/testcases/test_http_api/test_dataset_management/test_graphrag_tasks.py new file mode 100644 index 000000000..a805be9a6 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_graphrag_tasks.py @@ -0,0 +1,89 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import bulk_upload_documents, list_documents, parse_documents, run_graphrag, trace_graphrag +from utils import wait_for + + +@wait_for(200, 1, "Document parsing timeout") +def _parse_done(auth, dataset_id, document_ids=None): + res = list_documents(auth, dataset_id) + target_docs = res["data"]["docs"] + if document_ids is None: + return all(doc.get("run") == "DONE" for doc in target_docs) + target_ids = set(document_ids) + for doc in target_docs: + if doc.get("id") in target_ids and doc.get("run") != "DONE": + return False + return True + + +class TestGraphRAGTasks: + @pytest.mark.p2 + def test_trace_graphrag_before_run(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = trace_graphrag(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert res["data"] == {}, res + + @pytest.mark.p2 + def test_run_graphrag_no_documents(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = run_graphrag(HttpApiAuth, dataset_id) + assert res["code"] == 102, res + assert "No documents in Dataset" in res.get("message", ""), res + + @pytest.mark.p3 + def test_run_graphrag_returns_task_id(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_graphrag(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert res["data"].get("graphrag_task_id"), res + + @pytest.mark.p3 + def test_trace_graphrag_until_complete(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0, res + _parse_done(HttpApiAuth, dataset_id, document_ids) + + res = run_graphrag(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + + last_res = {} + + @wait_for(200, 1, "GraphRAG task timeout") + def condition(): + res = trace_graphrag(HttpApiAuth, dataset_id) + if res["code"] != 0: + return False + data = res.get("data") or {} + if not data: + return False + if data.get("task_type") != "graphrag": + return False + progress = data.get("progress") + if progress in (-1, 1, -1.0, 1.0): + last_res["res"] = res + return True + return False + + condition() + res = last_res["res"] + assert res["data"]["task_type"] == "graphrag", res + assert res["data"].get("progress") in (-1, 1, -1.0, 1.0), res diff --git a/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py new file mode 100644 index 000000000..61be5881d --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py @@ -0,0 +1,53 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import delete_knowledge_graph, knowledge_graph +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 0, "Authorization"), + (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 109, "API key is invalid"), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = knowledge_graph(invalid_auth, "dataset_id") + assert res["code"] == expected_code + assert expected_message in res.get("message", "") + + +class TestKnowledgeGraph: + @pytest.mark.p2 + def test_get_knowledge_graph_empty(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = knowledge_graph(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert "graph" in res["data"], res + assert "mind_map" in res["data"], res + assert isinstance(res["data"]["graph"], dict), res + assert isinstance(res["data"]["mind_map"], dict), res + + @pytest.mark.p2 + def test_delete_knowledge_graph(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = delete_knowledge_graph(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert res["data"] is True, res diff --git a/test/testcases/test_http_api/test_dataset_management/test_raptor_tasks.py b/test/testcases/test_http_api/test_dataset_management/test_raptor_tasks.py new file mode 100644 index 000000000..6358fc266 --- /dev/null +++ b/test/testcases/test_http_api/test_dataset_management/test_raptor_tasks.py @@ -0,0 +1,89 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import bulk_upload_documents, list_documents, parse_documents, run_raptor, trace_raptor +from utils import wait_for + + +@wait_for(200, 1, "Document parsing timeout") +def _parse_done(auth, dataset_id, document_ids=None): + res = list_documents(auth, dataset_id) + target_docs = res["data"]["docs"] + if document_ids is None: + return all(doc.get("run") == "DONE" for doc in target_docs) + target_ids = set(document_ids) + for doc in target_docs: + if doc.get("id") in target_ids and doc.get("run") != "DONE": + return False + return True + + +class TestRaptorTasks: + @pytest.mark.p2 + def test_trace_raptor_before_run(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = trace_raptor(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert res["data"] == {}, res + + @pytest.mark.p2 + def test_run_raptor_no_documents(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + res = run_raptor(HttpApiAuth, dataset_id) + assert res["code"] == 102, res + assert "No documents in Dataset" in res.get("message", ""), res + + @pytest.mark.p3 + def test_run_raptor_returns_task_id(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = run_raptor(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + assert res["data"].get("raptor_task_id"), res + + @pytest.mark.p3 + def test_trace_raptor_until_complete(self, HttpApiAuth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0, res + _parse_done(HttpApiAuth, dataset_id, document_ids) + + res = run_raptor(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + + last_res = {} + + @wait_for(200, 1, "RAPTOR task timeout") + def condition(): + res = trace_raptor(HttpApiAuth, dataset_id) + if res["code"] != 0: + return False + data = res.get("data") or {} + if not data: + return False + if data.get("task_type") != "raptor": + return False + progress = data.get("progress") + if progress in (-1, 1, -1.0, 1.0): + last_res["res"] = res + return True + return False + + condition() + res = last_res["res"] + assert res["data"]["task_type"] == "raptor", res + assert res["data"].get("progress") in (-1, 1, -1.0, 1.0), res diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_summary.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_summary.py new file mode 100644 index 000000000..6466c24ce --- /dev/null +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_metadata_summary.py @@ -0,0 +1,52 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Although the docs group this under "chunk management," the backend aggregates +# Document.meta_fields via document_service#get_metadata_summary and the test +# uses update_document, so it belongs with file/document management tests. +import pytest +from common import metadata_summary, update_document + + +def _summary_to_counts(summary): + counts = {} + for key, pairs in summary.items(): + counts[key] = {str(k): v for k, v in pairs} + return counts + + +class TestMetadataSummary: + @pytest.mark.p2 + def test_metadata_summary_counts(self, HttpApiAuth, add_documents_func): + dataset_id, document_ids = add_documents_func + payloads = [ + {"tags": ["foo", "bar"], "author": "alice"}, + {"tags": ["foo"], "author": "bob"}, + {"tags": ["bar", "baz"], "author": None}, + ] + for doc_id, meta_fields in zip(document_ids, payloads): + res = update_document(HttpApiAuth, dataset_id, doc_id, {"meta_fields": meta_fields}) + assert res["code"] == 0, res + + res = metadata_summary(HttpApiAuth, dataset_id) + assert res["code"] == 0, res + summary = res["data"]["summary"] + counts = _summary_to_counts(summary) + assert counts["tags"]["foo"] == 2, counts + assert counts["tags"]["bar"] == 2, counts + assert counts["tags"]["baz"] == 1, counts + assert counts["author"]["alice"] == 1, counts + assert counts["author"]["bob"] == 1, counts + assert "None" not in counts["author"], counts diff --git a/test/testcases/test_http_api/test_session_management/test_agent_completions.py b/test/testcases/test_http_api/test_session_management/test_agent_completions.py new file mode 100644 index 000000000..e34cc21ec --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_agent_completions.py @@ -0,0 +1,96 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import ( + agent_completions, + create_agent, + create_agent_session, + delete_agent, + delete_agent_sessions, + list_agents, +) + +AGENT_TITLE = "test_agent_http" +MINIMAL_DSL = { + "components": { + "begin": { + "obj": {"component_name": "Begin", "params": {}}, + "downstream": ["message"], + "upstream": [], + }, + "message": { + "obj": {"component_name": "Message", "params": {"content": ["{sys.query}"]}}, + "downstream": [], + "upstream": ["begin"], + }, + }, + "history": [], + "retrieval": [], + "path": [], + "globals": { + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + }, + "variables": {}, +} + +@pytest.fixture(scope="function") +def agent_id(HttpApiAuth, request): + res = list_agents(HttpApiAuth, {"page_size": 1000}) + assert res["code"] == 0, res + for agent in res.get("data", []): + if agent.get("title") == AGENT_TITLE: + delete_agent(HttpApiAuth, agent["id"]) + + res = create_agent(HttpApiAuth, {"title": AGENT_TITLE, "dsl": MINIMAL_DSL}) + assert res["code"] == 0, res + res = list_agents(HttpApiAuth, {"title": AGENT_TITLE}) + assert res["code"] == 0, res + assert res.get("data"), res + agent_id = res["data"][0]["id"] + + def cleanup(): + delete_agent_sessions(HttpApiAuth, agent_id) + delete_agent(HttpApiAuth, agent_id) + + request.addfinalizer(cleanup) + return agent_id + + +class TestAgentCompletions: + @pytest.mark.p2 + def test_agent_completion_stream_false(self, HttpApiAuth, agent_id): + res = create_agent_session(HttpApiAuth, agent_id, payload={}) + assert res["code"] == 0, res + session_id = res["data"]["id"] + + res = agent_completions( + HttpApiAuth, + agent_id, + {"question": "hello", "stream": False, "session_id": session_id}, + ) + assert res["code"] == 0, res + if isinstance(res["data"], dict): + assert isinstance(res["data"].get("data"), dict), res + content = res["data"]["data"].get("content", "") + assert content, res + assert "hello" in content, res + assert res["data"].get("session_id") == session_id, res + else: + assert isinstance(res["data"], str), res + assert res["data"].startswith("**ERROR**"), res diff --git a/test/testcases/test_http_api/test_session_management/test_agent_sessions.py b/test/testcases/test_http_api/test_session_management/test_agent_sessions.py new file mode 100644 index 000000000..6f1d65fa5 --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_agent_sessions.py @@ -0,0 +1,89 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import ( + create_agent, + create_agent_session, + delete_agent, + delete_agent_sessions, + list_agent_sessions, + list_agents, +) + +AGENT_TITLE = "test_agent_http" +MINIMAL_DSL = { + "components": { + "begin": { + "obj": {"component_name": "Begin", "params": {}}, + "downstream": ["message"], + "upstream": [], + }, + "message": { + "obj": {"component_name": "Message", "params": {"content": ["{sys.query}"]}}, + "downstream": [], + "upstream": ["begin"], + }, + }, + "history": [], + "retrieval": [], + "path": [], + "globals": { + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + }, + "variables": {}, +} + +@pytest.fixture(scope="function") +def agent_id(HttpApiAuth, request): + res = list_agents(HttpApiAuth, {"page_size": 1000}) + assert res["code"] == 0, res + for agent in res.get("data", []): + if agent.get("title") == AGENT_TITLE: + delete_agent(HttpApiAuth, agent["id"]) + + res = create_agent(HttpApiAuth, {"title": AGENT_TITLE, "dsl": MINIMAL_DSL}) + assert res["code"] == 0, res + res = list_agents(HttpApiAuth, {"title": AGENT_TITLE}) + assert res["code"] == 0, res + assert res.get("data"), res + agent_id = res["data"][0]["id"] + + def cleanup(): + delete_agent_sessions(HttpApiAuth, agent_id) + delete_agent(HttpApiAuth, agent_id) + + request.addfinalizer(cleanup) + return agent_id + + +class TestAgentSessions: + @pytest.mark.p2 + def test_create_list_delete_agent_sessions(self, HttpApiAuth, agent_id): + res = create_agent_session(HttpApiAuth, agent_id, payload={}) + assert res["code"] == 0, res + session_id = res["data"]["id"] + assert res["data"]["agent_id"] == agent_id, res + + res = list_agent_sessions(HttpApiAuth, agent_id, params={"id": session_id}) + assert res["code"] == 0, res + assert len(res["data"]) == 1, res + assert res["data"][0]["id"] == session_id, res + + res = delete_agent_sessions(HttpApiAuth, agent_id, {"ids": [session_id]}) + assert res["code"] == 0, res diff --git a/test/testcases/test_http_api/test_session_management/test_chat_completions.py b/test/testcases/test_http_api/test_session_management/test_chat_completions.py new file mode 100644 index 000000000..fa2e225ca --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_chat_completions.py @@ -0,0 +1,122 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import ( + bulk_upload_documents, + chat_completions, + create_chat_assistant, + create_session_with_chat_assistant, + delete_chat_assistants, + delete_session_with_chat_assistants, + list_documents, + parse_documents, +) +from utils import wait_for + + +@wait_for(200, 1, "Document parsing timeout") +def _parse_done(auth, dataset_id, document_ids=None): + res = list_documents(auth, dataset_id) + target_docs = res["data"]["docs"] + if document_ids is None: + return all(doc.get("run") == "DONE" for doc in target_docs) + target_ids = set(document_ids) + for doc in target_docs: + if doc.get("id") in target_ids and doc.get("run") != "DONE": + return False + return True + + +class TestChatCompletions: + @pytest.mark.p3 + def test_chat_completion_stream_false_with_session(self, HttpApiAuth, add_dataset_func, tmp_path, request): + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, tmp_path) + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0, res + _parse_done(HttpApiAuth, dataset_id, document_ids) + + res = create_chat_assistant(HttpApiAuth, {"name": "chat_completion_test", "dataset_ids": [dataset_id]}) + assert res["code"] == 0, res + chat_id = res["data"]["id"] + request.addfinalizer(lambda: delete_session_with_chat_assistants(HttpApiAuth, chat_id)) + request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + + res = create_session_with_chat_assistant(HttpApiAuth, chat_id, {"name": "session_for_completion"}) + assert res["code"] == 0, res + session_id = res["data"]["id"] + + res = chat_completions( + HttpApiAuth, + chat_id, + {"question": "hello", "stream": False, "session_id": session_id}, + ) + assert res["code"] == 0, res + assert isinstance(res["data"], dict), res + for key in ["answer", "reference", "audio_binary", "id", "session_id"]: + assert key in res["data"], res + assert res["data"]["session_id"] == session_id, res + + @pytest.mark.p2 + def test_chat_completion_invalid_chat(self, HttpApiAuth): + res = chat_completions( + HttpApiAuth, + "invalid_chat_id", + {"question": "hello", "stream": False, "session_id": "invalid_session"}, + ) + assert res["code"] == 102, res + assert "You don't own the chat" in res.get("message", ""), res + + @pytest.mark.p2 + def test_chat_completion_invalid_session(self, HttpApiAuth, request): + res = create_chat_assistant(HttpApiAuth, {"name": "chat_completion_invalid_session", "dataset_ids": []}) + assert res["code"] == 0, res + chat_id = res["data"]["id"] + request.addfinalizer(lambda: delete_session_with_chat_assistants(HttpApiAuth, chat_id)) + request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + + res = chat_completions( + HttpApiAuth, + chat_id, + {"question": "hello", "stream": False, "session_id": "invalid_session"}, + ) + assert res["code"] == 102, res + assert "You don't own the session" in res.get("message", ""), res + + @pytest.mark.p2 + def test_chat_completion_invalid_metadata_condition(self, HttpApiAuth, request): + res = create_chat_assistant(HttpApiAuth, {"name": "chat_completion_invalid_meta", "dataset_ids": []}) + assert res["code"] == 0, res + chat_id = res["data"]["id"] + request.addfinalizer(lambda: delete_session_with_chat_assistants(HttpApiAuth, chat_id)) + request.addfinalizer(lambda: delete_chat_assistants(HttpApiAuth)) + + res = create_session_with_chat_assistant(HttpApiAuth, chat_id, {"name": "session_for_meta"}) + assert res["code"] == 0, res + session_id = res["data"]["id"] + + res = chat_completions( + HttpApiAuth, + chat_id, + { + "question": "hello", + "stream": False, + "session_id": session_id, + "metadata_condition": "invalid", + }, + ) + assert res["code"] == 102, res + assert "metadata_condition" in res.get("message", ""), res diff --git a/test/testcases/test_http_api/test_session_management/test_related_questions.py b/test/testcases/test_http_api/test_session_management/test_related_questions.py new file mode 100644 index 000000000..427708b27 --- /dev/null +++ b/test/testcases/test_http_api/test_session_management/test_related_questions.py @@ -0,0 +1,39 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from common import related_questions +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowHttpApiAuth + + +class TestRelatedQuestions: + @pytest.mark.p3 + def test_related_questions_success(self, HttpApiAuth): + res = related_questions(HttpApiAuth, {"question": "ragflow", "industry": "search"}) + assert res["code"] == 0, res + assert isinstance(res.get("data"), list), res + + @pytest.mark.p2 + def test_related_questions_missing_question(self, HttpApiAuth): + res = related_questions(HttpApiAuth, {"industry": "search"}) + assert res["code"] == 102, res + assert "question" in res.get("message", ""), res + + @pytest.mark.p2 + def test_related_questions_invalid_auth(self): + res = related_questions(RAGFlowHttpApiAuth(INVALID_API_TOKEN), {"question": "ragflow", "industry": "search"}) + assert res["code"] == 109, res + assert "API key is invalid" in res.get("message", ""), res