def search_pages_path(pages_dir):
    """Return the page modules located directly inside ``pages_dir``.

    ``*_app.py`` files are listed first, followed by ``*_api.py`` files.
    Entries whose file name starts with a dot are skipped.
    """
    page_paths = []
    for pattern in ('*_app.py', '*_api.py'):
        page_paths.extend(
            path for path in pages_dir.glob(pattern) if not path.name.startswith('.'))
    return page_paths


def _split_page_stem(stem):
    """Split a module stem into ``(page_name, is_api)``.

    Only a literal trailing ``_api`` / ``_app`` is removed.  ``str.rstrip``
    must not be used for this: its argument is a *set of characters*, so
    ``'data_api'.rstrip('_api')`` yields ``'dat'`` rather than ``'data'``.
    """
    for suffix, is_api in (('_api', True), ('_app', False)):
        if stem.endswith(suffix):
            return stem[:-len(suffix)], is_api
    return stem, False


def register_page(page_path):
    """Load the page module at ``page_path`` and register its blueprint.

    The module is executed with a ``manager`` Blueprint already attached, so
    its ``@manager.route`` decorators bind to that blueprint.  ``*_api.py``
    modules are mounted under ``/api/<API_VERSION>/<page_name>``, all other
    pages under ``/<API_VERSION>/<page_name>``.

    Returns the URL prefix the blueprint was registered with.
    """
    # Decide by the file's own stem, not by a substring of the whole path:
    # a parent directory that happens to contain "_api" must not turn every
    # page into an API page.
    page_name, is_api = _split_page_stem(page_path.stem)
    module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name,))

    spec = spec_from_file_location(module_name, page_path)
    page = module_from_spec(spec)
    # NOTE(review): one line between the two visible diff hunks is not shown;
    # it is assumed to be the assignment below - confirm against the file.
    page.app = app
    page.manager = Blueprint(page_name, module_name)
    sys.modules[module_name] = page
    spec.loader.exec_module(page)

    if is_api:
        url_prefix = f'/api/{API_VERSION}/{page_name}'
    else:
        url_prefix = f'/{API_VERSION}/{page_name}'

    app.register_blueprint(page.manager, url_prefix=url_prefix)
    print(f'API file: {page_path}, URL: {url_prefix}')
    return url_prefix
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import re
from datetime import datetime, timedelta
from flask import request, Response
from flask_login import login_required, current_user

from api.db import FileType, ParserType, FileSource, StatusEnum
from api.db.db_models import APIToken, API4Conversation, Task, File
from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService, TenantService
from api.settings import RetCode, retrievaler
from api.utils import get_uuid, current_timestamp, datetime_format
# from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
from itsdangerous import URLSafeTimedSerializer

from api.utils.file_utils import filename_type, thumbnail
from rag.utils.minio_conn import MINIO

# import library
from api.utils.api_utils import construct_json_result, construct_result, construct_error_response, validate_request
from api.contants import NAME_LENGTH_LIMIT


# ------------------------------ create a dataset ---------------------------------------
@manager.route('/', methods=['POST'])
@login_required  # use login
@validate_request("name")  # check name key
def create_dataset():
    """Create a dataset owned by the logged-in user.

    The JSON request body must contain exactly one field, ``name``.  The name
    is stripped of surrounding whitespace, must be non-empty and at most
    ``NAME_LENGTH_LIMIT`` characters long.

    Returns a JSON response; on success its ``data`` is
    ``{"dataset_id": <new id>}``, otherwise ``code`` carries the error code
    and ``message`` the reason.
    """
    # Check if Authorization header is present
    if not request.headers.get('Authorization'):
        return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Authorization header is missing.")

    # TODO: Login or API key
    # objs = APIToken.query(token=authorization_token)
    #
    # # Authorization error
    # if not objs:
    #     return construct_json_result(code=RetCode.AUTHENTICATION_ERROR, message="Token is invalid.")
    #
    # tenant_id = objs[0].tenant_id

    tenant_id = current_user.id
    request_body = request.json

    # In case that there's no name
    if "name" not in request_body:
        return construct_json_result(code=RetCode.DATA_ERROR, message="Expected 'name' field in request body")

    dataset_name = request_body["name"]

    # A non-string name would otherwise blow up on .strip() below.
    if dataset_name is not None and not isinstance(dataset_name, str):
        return construct_json_result(code=RetCode.DATA_ERROR, message="Dataset name must be a string")

    # Strip BEFORE the emptiness check so that a whitespace-only name is
    # rejected instead of being stored as an empty string.
    dataset_name = dataset_name.strip() if dataset_name else ""
    if not dataset_name:
        return construct_json_result(code=RetCode.DATA_ERROR, message="Empty dataset name")

    # In case that the length of the name exceeds the limit
    dataset_name_length = len(dataset_name)
    if dataset_name_length > NAME_LENGTH_LIMIT:
        # An explicit error code is required: construct_json_result defaults
        # to RetCode.SUCCESS.
        return construct_json_result(
            code=RetCode.DATA_ERROR,
            message=f"Dataset name: {dataset_name} with length {dataset_name_length} exceeds {NAME_LENGTH_LIMIT}!")

    # In case that there are other fields in the data-binary
    if len(request_body) > 1:
        name_list = [key_name for key_name in request_body if key_name != 'name']
        return construct_json_result(code=RetCode.DATA_ERROR,
                                     message=f"fields: {name_list}, are not allowed in request body.")

    try:
        # If there is a duplicate name, it will modify it to make it unique.
        # Kept inside the try so a failing lookup is reported like any other error.
        request_body["name"] = duplicate_name(
            KnowledgebaseService.query,
            name=dataset_name,
            tenant_id=tenant_id,
            status=StatusEnum.VALID.value)
        request_body["id"] = get_uuid()
        request_body["tenant_id"] = tenant_id
        request_body["created_by"] = tenant_id
        found, tenant = TenantService.get_by_id(tenant_id)
        if not found:
            return construct_result(code=RetCode.AUTHENTICATION_ERROR, message="Tenant not found.")
        request_body["embd_id"] = tenant.embd_id
        if not KnowledgebaseService.save(**request_body):
            # failed to create new dataset
            return construct_result()
        return construct_json_result(data={"dataset_id": request_body["id"]})
    except Exception as e:
        return construct_error_response(e)


# The three views below take ``dataset_id``, so their rule has to declare the
# matching URL variable; with a bare '/' the view would be called without its
# argument, and get_dataset would collide with list_datasets on GET '/'.
@manager.route('/<dataset_id>', methods=['DELETE'])
@login_required
def remove_dataset(dataset_id):
    """Placeholder: removal is not implemented yet; always answers DATA_ERROR."""
    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to remove dataset: {dataset_id}")


@manager.route('/<dataset_id>', methods=['PUT'])
@login_required
@validate_request("name")
def update_dataset(dataset_id):
    """Placeholder: update is not implemented yet; always answers DATA_ERROR."""
    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to update dataset: {dataset_id}")


@manager.route('/<dataset_id>', methods=['GET'])
@login_required
def get_dataset(dataset_id):
    """Placeholder: detail lookup is not implemented yet; always answers DATA_ERROR."""
    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")


@manager.route('/', methods=['GET'])
@login_required
def list_datasets():
    """Placeholder: listing is not implemented yet; always answers DATA_ERROR."""
    return construct_json_result(code=RetCode.DATA_ERROR, message="attempt to list datasets")
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
    """Build a JSON response holding ``code`` and, when given, ``message``.

    Every case-insensitive occurrence of "rag" in ``message`` is replaced by
    "seceum".  ``code`` is always included; a ``None`` message is left out.
    """
    import re
    response = {"code": code}
    # Guard first: re.sub() raises TypeError on None, which previously made
    # the "skip None values" filtering unreachable for the message.
    if message is not None:
        response["message"] = re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)
    return jsonify(response)


def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
    """Build a JSON response with ``code`` and ``message``; ``data`` is added only when given."""
    # Identity check: "== None" would dispatch to a custom __eq__ on data.
    if data is None:
        return jsonify({"code": code, "message": message})
    return jsonify({"code": code, "message": message, "data": data})


def construct_error_response(e):
    """Log exception ``e`` and translate it into a JSON error response.

    * an exception carrying ``code == 401`` maps to ``RetCode.UNAUTHORIZED``;
    * one with more than one argument reports ``args[0]`` as the message and
      ``args[1]`` as the data;
    * an "index_not_found_exception" maps to a hint to upload and parse a file;
    * anything else reports ``repr(e)``.
    """
    stat_logger.exception(e)

    # Probe the optional attribute directly instead of wrapping the lookup in
    # a try/except BaseException that silently swallows every error.
    if getattr(e, "code", None) == 401:
        return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))

    if len(e.args) > 1:
        return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])

    if "index_not_found_exception" in repr(e):
        return construct_json_result(code=RetCode.EXCEPTION_ERROR,
                                     message="No chunk found, please upload file and parse it.")

    return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
class RAGFLow:
    """Small HTTP client for the dataset endpoints of a RAGFlow server."""

    def __init__(self, user_key, base_url, version='v1'):
        """
        user_key: value sent verbatim in the ``Authorization`` header.
        base_url: server address, e.g. ``http://127.0.0.1:9380``; a trailing
                  slash is tolerated.
        version:  API version segment used to build the URLs.

        api_url:     <base_url>/api/<version>
        dataset_url: <base_url>/api/<version>/dataset
        """
        self.user_key = user_key
        # Normalise so that "http://host/" does not produce "http://host//api/...".
        # The attribute itself is kept for callers of the earlier client.
        self.base_url = base_url.rstrip('/')
        self.api_url = f"{self.base_url}/api/{version}"
        self.dataset_url = f"{self.api_url}/dataset"
        self.authorization_header = {"Authorization": "{}".format(self.user_key)}

    def create_dataset(self, dataset_name):
        """
        dataset_name: name of the dataset to create.

        Returns the server's JSON answer decoded into a dict.
        """
        res = requests.post(url=self.dataset_url, json={"name": dataset_name}, headers=self.authorization_header)
        return json.loads(res.text)

    def delete_dataset(self, dataset_name=None, dataset_id=None):
        """Placeholder: nothing is sent to the server yet; echoes ``dataset_name``."""
        return dataset_name

    def list_dataset(self):
        """Return the ``datasets`` entry of the server's answer, or None unless the status is 200."""
        # Same credentials as create_dataset, so every call is authenticated alike.
        response = requests.get(self.dataset_url, headers=self.authorization_header)
        if response.status_code == 200:
            return response.json()['datasets']
        return None

    def get_dataset(self, dataset_id):
        """Return the decoded JSON answer for ``dataset_id``, or None unless the status is 200."""
        endpoint = f"{self.dataset_url}/{dataset_id}"
        response = requests.get(endpoint, headers=self.authorization_header)
        if response.status_code == 200:
            return response.json()
        return None

    def update_dataset(self, dataset_id, params):
        """Send ``params`` as the JSON body of a PUT for ``dataset_id``; True when the status is 200."""
        endpoint = f"{self.dataset_url}/{dataset_id}"
        response = requests.put(endpoint, json=params, headers=self.authorization_header)
        if response.status_code == 200:
            return True
        # NOTE(review): the non-200 branch lies outside the visible diff hunk;
        # it is assumed to return False - confirm against the file.
        return False
mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = {'datasets': [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]} - - # Patching requests.get to return the mock_response - monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response)) - - # Call the method under test - result = ragflow_instance.list_dataset() - - # Assertion - assert result == [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}] - - def test_list_dataset_failure(self, ragflow_instance, monkeypatch): - # Mocking the response of requests.get method - mock_response = MagicMock() - mock_response.status_code = 404 # Simulating a failed request - - # Patching requests.get to return the mock_response - monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response)) - - # Call the method under test - result = ragflow_instance.list_dataset() - - # Assertion - assert result is None + # def test_create_dataset(self): + # res = RAGFLow(API_KEY, HOST_ADDRESS).create_dataset('abc') + # print(res) + # + # def test_delete_dataset(self): + # assert RAGFLow('123', 'url').delete_dataset('abc') == 'abc' + # + # def test_list_dataset_success(self, ragflow_instance, monkeypatch): + # # Mocking the response of requests.get method + # mock_response = MagicMock() + # mock_response.status_code = 200 + # mock_response.json.return_value = {'datasets': [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}]} + # + # # Patching requests.get to return the mock_response + # monkeypatch.setattr("requests.get", MagicMock(return_value=mock_response)) + # + # # Call the method under test + # result = ragflow_instance.list_dataset() + # + # # Assertion + # assert result == [{'id': 1, 'name': 'dataset1'}, {'id': 2, 'name': 'dataset2'}] + # + # def test_list_dataset_failure(self, ragflow_instance, monkeypatch): + # # Mocking the response of requests.get method + # mock_response = MagicMock() + # mock_response.status_code = 404 # 
from test_sdkbase import TestSdk
import ragflow
from ragflow.ragflow import RAGFLow
import pytest
from unittest.mock import MagicMock
from common import API_KEY, HOST_ADDRESS


class TestDataset(TestSdk):
    """Dataset tests that talk to the server configured in ``common``."""

    def test_create_dataset(self):
        '''
        1. create a kb
        2. list the kb
        3. get the detail info according to the kb id
        4. update the kb
        5. delete the kb

        Only step 1 is implemented so far.
        '''
        # Named "client" so the imported ``ragflow`` module is not shadowed.
        client = RAGFLow(API_KEY, HOST_ADDRESS)

        # create a kb
        res = client.create_dataset("kb1")
        # Separate asserts so a failure shows which field was wrong.
        assert res['code'] == 0
        assert res['message'] == 'success'
        dataset_id = res['data']['dataset_id']
        print(dataset_id)

        # TODO: list the kb