Feat: Alter flask to Quart for async API serving. (#11275)

### What problem does this PR solve?

#11277

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-11-18 17:05:16 +08:00
committed by GitHub
parent c2b7c305fa
commit d1716d865a
49 changed files with 4120 additions and 3888 deletions

View File

@ -18,12 +18,11 @@ import sys
import logging
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from flask import Blueprint, Flask
from quart import Blueprint, Quart, request, g, current_app, session
from werkzeug.wrappers.request import Request
from flask_cors import CORS
from flasgger import Swagger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
from common.constants import StatusEnum
from api.db.db_models import close_connection
from api.db.services import UserService
@ -31,17 +30,20 @@ from api.utils.json_encode import CustomJSONEncoder
from api.utils import commands
from flask_mail import Mail
from flask_session import Session
from flask_login import LoginManager
from quart_auth import Unauthorized
from common import settings
from api.utils.api_utils import server_error_response
from api.constants import API_VERSION
from common.misc_utils import get_uuid
settings.init_settings()
__all__ = ["app"]
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Flask(__name__)
app = Quart(__name__)
app = cors(app, allow_origin="*")
smtp_mail_server = Mail()
# Add this at the beginning of your file to configure Swagger UI
@ -76,7 +78,6 @@ swagger = Swagger(
},
)
CORS(app, supports_credentials=True, max_age=2592000)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
@ -84,17 +85,143 @@ app.errorhandler(Exception)(server_error_response)
## convince for dev and debug
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
app.config["SESSION_TYPE"] = "filesystem"
app.config["SESSION_TYPE"] = "redis"
app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis")
app.config["MAX_CONTENT_LENGTH"] = int(
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
)
Session(app)
login_manager = LoginManager()
login_manager.init_app(app)
app.config['SECRET_KEY'] = settings.SECRET_KEY
app.secret_key = settings.SECRET_KEY
commands.register_commands(app)
from functools import wraps
from typing import ParamSpec, TypeVar
from collections.abc import Awaitable, Callable
from werkzeug.local import LocalProxy
T = TypeVar("T")
P = ParamSpec("P")
def _load_user():
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = request.headers.get("Authorization")
g.user = None
if not authorization:
return
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return None
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return None
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
g.user = user[0]
return user[0]
except Exception as e:
logging.warning(f"load_user got exception {e}")
current_user = LocalProxy(_load_user)
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
"""A decorator to restrict route access to authenticated users.
This should be used to wrap a route handler (or view function) to
enforce that only authenticated requests can access it. Note that
it is important that this decorator be wrapped by the route
decorator and not vice, versa, as below.
.. code-block:: python
@app.route('/')
@login_required
async def index():
...
If the request is not authenticated a
`quart.exceptions.Unauthorized` exception will be raised.
"""
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if not current_user:# or not session.get("_user_id"):
raise Unauthorized()
else:
return await current_app.ensure_async(func)(*args, **kwargs)
return wrapper
def login_user(user, remember=False, duration=None, force=False, fresh=True):
"""
Logs a user in. You should pass the actual user object to this. If the
user's `is_active` property is ``False``, they will not be logged in
unless `force` is ``True``.
This will return ``True`` if the log in attempt succeeds, and ``False`` if
it fails (i.e. because the user is inactive).
:param user: The user object to log in.
:type user: object
:param remember: Whether to remember the user after their session expires.
Defaults to ``False``.
:type remember: bool
:param duration: The amount of time before the remember cookie expires. If
``None`` the value set in the settings is used. Defaults to ``None``.
:type duration: :class:`datetime.timedelta`
:param force: If the user is inactive, setting this to ``True`` will log
them in regardless. Defaults to ``False``.
:type force: bool
:param fresh: setting this to ``False`` will log in the user with a session
marked as not "fresh". Defaults to ``True``.
:type fresh: bool
"""
if not force and not user.is_active:
return False
session["_user_id"] = user.id
session["_fresh"] = fresh
session["_id"] = get_uuid()
return True
def logout_user():
"""
Logs a user out. (You do not need to pass the actual user.) This will
also clean up the remember me cookie if it exists.
"""
if "_user_id" in session:
session.pop("_user_id")
if "_fresh" in session:
session.pop("_fresh")
if "_id" in session:
session.pop("_id")
COOKIE_NAME = "remember_token"
cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
if cookie_name in request.cookies:
session["_remember"] = "clear"
if "_remember_seconds" in session:
session.pop("_remember_seconds")
return True
def search_pages_path(page_path):
app_path_list = [
@ -142,40 +269,6 @@ client_urls_prefix = [
]
@login_manager.request_loader
def load_user(web_request):
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = web_request.headers.get("Authorization")
if authorization:
try:
access_token = str(jwt.loads(authorization))
if not access_token or not access_token.strip():
logging.warning("Authentication attempt with empty access token")
return None
# Access tokens should be UUIDs (32 hex characters)
if len(access_token.strip()) < 32:
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
return None
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
return user[0]
else:
return None
except Exception as e:
logging.warning(f"load_user got exception {e}")
return None
else:
return None
@app.teardown_request
def _db_close(exception):
if exception:

View File

@ -14,20 +14,20 @@
# limitations under the License.
#
from datetime import datetime, timedelta
from flask import request
from flask_login import login_required, current_user
from quart import request
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
from common.time_utils import current_timestamp, datetime_format
from api.apps import login_required, current_user
@manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required
def new_token():
req = request.json
async def new_token():
req = await request.json
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
@ -72,8 +72,8 @@ def token_list():
@manager.route('/rm', methods=['POST']) # noqa: F821
@validate_request("tokens", "tenant_id")
@login_required
def rm():
req = request.json
async def rm():
req = await request.json
try:
for token in req["tokens"]:
APITokenService.filter_delete(

View File

@ -18,12 +18,8 @@ import logging
import re
import sys
from functools import partial
import flask
import trio
from flask import request, Response
from flask_login import login_required, current_user
from quart import request, Response, make_response
from agent.component import LLM
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
@ -35,7 +31,8 @@ from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
request_json
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@ -46,6 +43,7 @@ from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from api.apps import login_required, current_user
@manager.route('/templates', methods=['GET']) # noqa: F821
@ -57,8 +55,9 @@ def templates():
@manager.route('/rm', methods=['POST']) # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
async def rm():
req = await request_json()
for i in req["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -70,8 +69,8 @@ def rm():
@manager.route('/set', methods=['POST']) # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
async def save():
req = await request_json()
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
@ -129,8 +128,8 @@ def getsse(canvas_id):
@manager.route('/completion', methods=['POST']) # noqa: F821
@validate_request("id")
@login_required
def run():
req = request.json
async def run():
req = await request_json()
query = req.get("query", "")
files = req.get("files", [])
inputs = req.get("inputs", {})
@ -160,10 +159,10 @@ def run():
except Exception as e:
return server_error_response(e)
def sse():
async def sse():
nonlocal canvas, user_id
try:
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
cvs.dsl = json.loads(str(canvas))
@ -179,15 +178,15 @@ def run():
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
resp.call_on_close(lambda: canvas.cancel_task())
#resp.call_on_close(lambda: canvas.cancel_task())
return resp
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
def rerun():
req = request.json
async def rerun():
req = await request_json()
doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc:
return get_data_error_result(message="Document not found.")
@ -224,8 +223,8 @@ def cancel(task_id):
@manager.route('/reset', methods=['POST']) # noqa: F821
@validate_request("id")
@login_required
def reset():
req = request.json
async def reset():
req = await request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -245,7 +244,7 @@ def reset():
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
async def upload(canvas_id):
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
@ -311,7 +310,8 @@ def upload(canvas_id):
except Exception as e:
return server_error_response(e)
file = request.files['file']
files = await request.files
file = files['file']
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
@ -342,8 +342,8 @@ def input_form():
@manager.route('/debug', methods=['POST']) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
async def debug():
req = await request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@ -374,8 +374,8 @@ def debug():
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
def test_db_connect():
req = request.json
async def test_db_connect():
req = await request_json()
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
@ -519,8 +519,8 @@ def list_canvas():
@manager.route('/setting', methods=['POST']) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
req = request.json
async def setting():
req = await request_json()
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):
@ -601,8 +601,8 @@ def prompts():
@manager.route('/download', methods=['GET']) # noqa: F821
def download():
async def download():
id = request.args.get("id")
created_by = request.args.get("created_by")
blob = FileService.get_blob(created_by, id)
return flask.make_response(blob)
return await make_response(blob)

View File

@ -18,8 +18,7 @@ import json
import re
import xxhash
from flask import request
from flask_login import current_user, login_required
from quart import request
from api.db.services.dialog_service import meta_filter
from api.db.services.document_service import DocumentService
@ -27,7 +26,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -35,13 +35,14 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
from common import settings
from api.apps import login_required, current_user
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id")
def list_chunk():
req = request.json
async def list_chunk():
req = await request_json()
doc_id = req["doc_id"]
page = int(req.get("page", 1))
size = int(req.get("size", 30))
@ -121,8 +122,8 @@ def get():
@manager.route('/set', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id", "chunk_id", "content_with_weight")
def set():
req = request.json
async def set():
req = await request_json()
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
@ -178,8 +179,8 @@ def set():
@manager.route('/switch', methods=['POST']) # noqa: F821
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
def switch():
req = request.json
async def switch():
req = await request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@ -198,8 +199,8 @@ def switch():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("chunk_ids", "doc_id")
def rm():
req = request.json
async def rm():
req = await request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@ -222,8 +223,8 @@ def rm():
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id", "content_with_weight")
def create():
req = request.json
async def create():
req = await request_json()
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
@ -280,8 +281,8 @@ def create():
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
@login_required
@validate_request("kb_id", "question")
def retrieval_test():
req = request.json
async def retrieval_test():
req = await request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]

View File

@ -20,8 +20,9 @@ import uuid
from html import escape
from typing import Any
from flask import make_response, request
from flask_login import current_user, login_required
import flask
import trio
from quart import request
from google_auth_oauthlib.flow import Flow
from api.db import InputType
@ -32,12 +33,13 @@ from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, Docum
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
from common.misc_utils import get_uuid
from rag.utils.redis_conn import REDIS_CONN
from api.apps import login_required, current_user
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_connector():
req = request.json
async def set_connector():
req = await request.json
if req.get("id"):
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
ConnectorService.update_by_id(req["id"], conn)
@ -57,7 +59,7 @@ def set_connector():
}
ConnectorService.save(**conn)
time.sleep(1)
await trio.sleep(1)
e, conn = ConnectorService.get_by_id(req["id"])
return get_json_result(data=conn.to_dict())
@ -88,8 +90,8 @@ def list_logs(connector_id):
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
@login_required
def resume(connector_id):
req = request.json
async def resume(connector_id):
req = await request.json
if req.get("resume"):
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
else:
@ -100,8 +102,8 @@ def resume(connector_id):
@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("kb_id")
def rebuild(connector_id):
req = request.json
async def rebuild(connector_id):
req = await request.json
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
if err:
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
@ -163,7 +165,7 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
payload_json=payload_json,
auto_close=auto_close,
)
response = make_response(html, 200)
response = flask.make_response(html, 200)
response.headers["Content-Type"] = "text/html; charset=utf-8"
return response
@ -171,14 +173,14 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
@login_required
@validate_request("credentials")
def start_google_drive_web_oauth():
async def start_google_drive_web_oauth():
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
return get_json_result(
code=RetCode.SERVER_ERROR,
message="Google Drive OAuth redirect URI is not configured on the server.",
)
req = request.json or {}
req = await request.json or {}
raw_credentials = req.get("credentials", "")
try:
credentials = _load_credentials(raw_credentials)
@ -279,8 +281,8 @@ def google_drive_web_oauth_callback():
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
@login_required
@validate_request("flow_id")
def poll_google_drive_web_result():
req = request.json or {}
async def poll_google_drive_web_result():
req = await request.json or {}
flow_id = req.get("flow_id")
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
if not cache_raw:

View File

@ -17,8 +17,8 @@ import json
import re
import logging
from copy import deepcopy
from flask import Response, request
from flask_login import current_user, login_required
from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
@ -34,8 +34,8 @@ from common.constants import RetCode, LLMType
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_conversation():
req = request.json
async def set_conversation():
req = await request.json
conv_id = req.get("conversation_id")
is_new = req.get("is_new")
name = req.get("name", "New conversation")
@ -128,8 +128,9 @@ def getsse(dialog_id):
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
def rm():
conv_ids = request.json["conversation_ids"]
async def rm():
req = await request.json
conv_ids = req["conversation_ids"]
try:
for cid in conv_ids:
exist, conv = ConversationService.get_by_id(cid)
@ -165,8 +166,8 @@ def list_conversation():
@manager.route("/completion", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
req = request.json
async def completion():
req = await request.json
msg = []
for m in req["messages"]:
if m["role"] == "system":
@ -250,8 +251,8 @@ def completion():
@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required
def tts():
req = request.json
async def tts():
req = await request.json
text = req["text"]
tenants = TenantService.get_info_by(current_user.id)
@ -283,8 +284,8 @@ def tts():
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def delete_msg():
req = request.json
async def delete_msg():
req = await request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@ -306,8 +307,8 @@ def delete_msg():
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def thumbup():
req = request.json
async def thumbup():
req = await request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@ -333,8 +334,8 @@ def thumbup():
@manager.route("/ask", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def ask_about():
req = request.json
async def ask_about():
req = await request.json
uid = current_user.id
search_id = req.get("search_id", "")
@ -365,8 +366,8 @@ def ask_about():
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def mindmap():
req = request.json
async def mindmap():
req = await request.json
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {}
@ -383,8 +384,8 @@ def mindmap():
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
def related_questions():
req = request.json
async def related_questions():
req = await request.json
search_id = req.get("search_id", "")
search_config = {}

View File

@ -14,8 +14,7 @@
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from quart import request
from api.db.services import duplicate_name
from api.db.services.dialog_service import DialogService
from common.constants import StatusEnum
@ -26,13 +25,14 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.utils.api_utils import get_json_result
from api.apps import login_required, current_user
@manager.route('/set', methods=['POST']) # noqa: F821
@validate_request("prompt_config")
@login_required
def set_dialog():
req = request.json
async def set_dialog():
req = await request.json
dialog_id = req.get("dialog_id", "")
is_create = not dialog_id
name = req.get("name", "New Dialog")
@ -169,18 +169,19 @@ def list_dialogs():
@manager.route('/next', methods=['POST']) # noqa: F821
@login_required
def list_dialogs_next():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
parser_id = request.args.get("parser_id")
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
async def list_dialogs_next():
args = request.args
keywords = args.get("keywords", "")
page_number = int(args.get("page", 0))
items_per_page = int(args.get("page_size", 0))
parser_id = args.get("parser_id")
orderby = args.get("orderby", "create_time")
if args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
req = request.get_json()
req = await request.get_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -207,8 +208,8 @@ def list_dialogs_next():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("dialog_ids")
def rm():
req = request.json
async def rm():
req = await request.json
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try:

View File

@ -18,11 +18,8 @@ import os.path
import pathlib
import re
from pathlib import Path
import flask
from flask import request
from flask_login import current_user, login_required
from quart import request, make_response
from api.apps import current_user, login_required
from api.common.check_team_permission import check_kb_team_permission
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.db import VALID_FILE_TYPES, FileType
@ -39,7 +36,7 @@ from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
validate_request, request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory
@ -53,14 +50,16 @@ from common import settings
@manager.route("/upload", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id")
def upload():
kb_id = request.form.get("kb_id")
async def upload():
form = await request.form
kb_id = form.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
if "file" not in request.files:
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
@ -87,12 +86,13 @@ def upload():
@manager.route("/web_crawl", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id", "name", "url")
def web_crawl():
kb_id = request.form.get("kb_id")
async def web_crawl():
form = await request.form
kb_id = form.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
name = request.form.get("name")
url = request.form.get("url")
name = form.get("name")
url = form.get("url")
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
@ -152,8 +152,8 @@ def web_crawl():
@manager.route("/create", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "kb_id")
def create():
req = request.json
async def create():
req = await request_json()
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -208,7 +208,7 @@ def create():
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_docs():
async def list_docs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -230,7 +230,7 @@ def list_docs():
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))
req = request.get_json()
req = await request.get_json()
run_status = req.get("run_status", [])
if run_status:
@ -270,8 +270,8 @@ def list_docs():
@manager.route("/filter", methods=["POST"]) # noqa: F821
@login_required
def get_filter():
req = request.get_json()
async def get_filter():
req = await request.get_json()
kb_id = req.get("kb_id")
if not kb_id:
@ -308,8 +308,8 @@ def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
def doc_infos():
req = request.json
async def doc_infos():
req = await request_json()
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
@ -340,8 +340,8 @@ def thumbnails():
@manager.route("/change_status", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids", "status")
def change_status():
req = request.get_json()
async def change_status():
req = await request.get_json()
doc_ids = req.get("doc_ids", [])
status = str(req.get("status", ""))
@ -380,8 +380,8 @@ def change_status():
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id")
def rm():
req = request.json
async def rm():
req = await request_json()
doc_ids = req["doc_id"]
if isinstance(doc_ids, str):
doc_ids = [doc_ids]
@ -401,8 +401,8 @@ def rm():
@manager.route("/run", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids", "run")
def run():
req = request.json
async def run():
req = await request_json()
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -448,8 +448,8 @@ def run():
@manager.route("/rename", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "name")
def rename():
req = request.json
async def rename():
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
@ -495,14 +495,14 @@ def rename():
@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
# @login_required
def get(doc_id):
async def get(doc_id):
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))
response = await make_response(settings.STORAGE_IMPL.get(b, n))
ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
@ -520,12 +520,12 @@ def get(doc_id):
@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
@login_required
def download_attachment(attachment_id):
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
response = flask.make_response(data)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
return response
@ -537,9 +537,9 @@ def download_attachment(attachment_id):
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id")
def change_parser():
async def change_parser():
req = request.json
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -590,13 +590,13 @@ def change_parser():
@manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
# @login_required
def get_image(image_id):
async def get_image(image_id):
try:
arr = image_id.split("-")
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
response.headers.set("Content-Type", "image/JPEG")
return response
except Exception as e:
@ -606,24 +606,25 @@ def get_image(image_id):
@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id")
def upload_and_parse():
if "file" not in request.files:
async def upload_and_parse():
files = await request.file
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
form = await request.form
doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
return get_json_result(data=doc_ids)
@manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required
def parse():
url = request.json.get("url") if request.json else ""
async def parse():
url = await request.json.get("url") if await request.json else ""
if url:
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
@ -664,10 +665,11 @@ def parse():
txt = FileService.parse_docs([f], current_user.id)
return get_json_result(data=txt)
if "file" not in request.files:
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
txt = FileService.parse_docs(file_objs, current_user.id)
return get_json_result(data=txt)
@ -676,8 +678,8 @@ def parse():
@manager.route("/set_meta", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "meta")
def set_meta():
req = request.json
async def set_meta():
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:

View File

@ -19,8 +19,8 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from flask import request
from flask_login import login_required, current_user
from quart import request
from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from common.misc_utils import get_uuid
@ -33,8 +33,8 @@ from api.utils.api_utils import get_json_result
@manager.route('/convert', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
def convert():
req = request.json
async def convert():
req = await request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
@ -103,8 +103,8 @@ def convert():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
async def rm():
req = await request.json
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(

View File

@ -17,10 +17,8 @@ import logging
import os
import pathlib
import re
import flask
from flask import request
from flask_login import login_required, current_user
from quart import request, make_response
from api.apps import login_required, current_user
from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
@ -40,17 +38,19 @@ from common import settings
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
def upload():
pf_id = request.form.get("parent_id")
async def upload():
form = await request.form
pf_id = form.get("parent_id")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
if 'file' not in request.files:
files = await request.files
if 'file' not in files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
file_objs = files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
@ -123,10 +123,10 @@ def upload():
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.json
pf_id = request.json.get("parent_id")
input_file_type = request.json.get("type")
async def create():
req = await request.json
pf_id = await request.json.get("parent_id")
input_file_type = await request.json.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@ -238,8 +238,8 @@ def get_all_parent_folders():
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
async def rm():
req = await request.json
file_ids = req["file_ids"]
def _delete_single_file(file):
@ -299,8 +299,8 @@ def rm():
@manager.route('/rename', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_id", "name")
def rename():
req = request.json
async def rename():
req = await request.json
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@ -338,7 +338,7 @@ def rename():
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
@login_required
def get(file_id):
async def get(file_id):
try:
e, file = FileService.get_by_id(file_id)
if not e:
@ -351,7 +351,7 @@ def get(file_id):
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = settings.STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
ext = ext.group(1) if ext else None
if ext:
@ -368,8 +368,8 @@ def get(file_id):
@manager.route("/mv", methods=["POST"]) # noqa: F821
@login_required
@validate_request("src_file_ids", "dest_file_id")
def move():
req = request.json
async def move():
req = await request.json
try:
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]

View File

@ -18,11 +18,9 @@ import logging
import random
import re
from flask import request
from flask_login import login_required, current_user
from quart import request
import numpy as np
from api.db.services.connector_service import Connector2KbService
from api.db.services.llm_service import LLMBundle
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
@ -31,7 +29,8 @@ from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
request_json
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@ -42,27 +41,28 @@ from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
from api.apps import login_required, current_user
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.json
req = KnowledgebaseService.create_with_name(
async def create():
req = await request_json()
e, res = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = current_user.id,
parser_id = req.pop("parser_id", None),
**req
)
code = req.get("code")
if code:
return get_data_error_result(code=code, message=req.get("message"))
if not e:
return res
try:
if not KnowledgebaseService.save(**req):
if not KnowledgebaseService.save(**res):
return get_data_error_result()
return get_json_result(data={"kb_id":req["id"]})
return get_json_result(data={"kb_id":res["id"]})
except Exception as e:
return server_error_response(e)
@ -71,8 +71,8 @@ def create():
@login_required
@validate_request("kb_id", "name", "description", "parser_id")
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
req = request.json
async def update():
req = await request_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "":
@ -170,18 +170,19 @@ def detail():
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
def list_kbs():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
parser_id = request.args.get("parser_id")
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
async def list_kbs():
args = request.args
keywords = args.get("keywords", "")
page_number = int(args.get("page", 0))
items_per_page = int(args.get("page_size", 0))
parser_id = args.get("parser_id")
orderby = args.get("orderby", "create_time")
if args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
req = request.get_json()
req = await request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -203,11 +204,12 @@ def list_kbs():
except Exception as e:
return server_error_response(e)
@manager.route('/rm', methods=['post']) # noqa: F821
@login_required
@validate_request("kb_id")
def rm():
req = request.json
async def rm():
req = await request_json()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
@ -283,8 +285,8 @@ def list_tags_from_kbs():
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required
def rm_tags(kb_id):
req = request.json
async def rm_tags(kb_id):
req = await request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@ -303,8 +305,8 @@ def rm_tags(kb_id):
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required
def rename_tags(kb_id):
req = request.json
async def rename_tags(kb_id):
req = await request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@ -407,7 +409,7 @@ def get_basic_info():
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_logs():
async def list_pipeline_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -426,7 +428,7 @@ def list_pipeline_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json()
req = await request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@ -451,7 +453,7 @@ def list_pipeline_logs():
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_dataset_logs():
async def list_pipeline_dataset_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -468,7 +470,7 @@ def list_pipeline_dataset_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json()
req = await request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@ -485,12 +487,12 @@ def list_pipeline_dataset_logs():
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
@login_required
def delete_pipeline_logs():
async def delete_pipeline_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
req = request.get_json()
req = await request_json()
log_ids = req.get("log_ids", [])
PipelineOperationLogService.delete_by_ids(log_ids)
@ -514,8 +516,8 @@ def pipeline_log_detail():
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required
def run_graphrag():
req = request.json
async def run_graphrag():
req = await request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -583,8 +585,8 @@ def trace_graphrag():
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required
def run_raptor():
req = request.json
async def run_raptor():
req = await request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -652,8 +654,8 @@ def trace_raptor():
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required
def run_mindmap():
req = request.json
async def run_mindmap():
req = await request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@ -768,7 +770,7 @@ def delete_kb_task():
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
@login_required
def check_embedding():
async def check_embedding():
def _guess_vec_field(src: dict) -> str | None:
for k in src or {}:
@ -859,7 +861,7 @@ def check_embedding():
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
req = request.json
req = await request_json()
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5))

View File

@ -15,8 +15,8 @@
#
from flask import request
from flask_login import current_user, login_required
from quart import request
from api.apps import current_user, login_required
from langfuse import Langfuse
from api.db.db_models import DB
@ -27,8 +27,8 @@ from api.utils.api_utils import get_error_data_result, get_json_result, server_e
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
@login_required
@validate_request("secret_key", "public_key", "host")
def set_api_key():
req = request.get_json()
async def set_api_key():
req = await request.get_json()
secret_key = req.get("secret_key", "")
public_key = req.get("public_key", "")
host = req.get("host", "")

View File

@ -16,8 +16,9 @@
import logging
import json
import os
from flask import request
from flask_login import login_required, current_user
from quart import request
from api.apps import login_required, current_user
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
@ -52,8 +53,8 @@ def factories():
@manager.route("/set_api_key", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
req = request.json
async def set_api_key():
req = await request.json
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
@ -122,8 +123,8 @@ def set_api_key():
@manager.route("/add_llm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory")
def add_llm():
req = request.json
async def add_llm():
req = await request.json
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
@ -142,11 +143,11 @@ def add_llm():
elif factory == "Tencent Hunyuan":
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
return set_api_key()
return await set_api_key()
elif factory == "Tencent Cloud":
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
return set_api_key()
return await set_api_key()
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
@ -267,8 +268,8 @@ def add_llm():
@manager.route("/delete_llm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory", "llm_name")
def delete_llm():
req = request.json
async def delete_llm():
req = await request.json
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True)
@ -276,8 +277,8 @@ def delete_llm():
@manager.route("/enable_llm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory", "llm_name")
def enable_llm():
req = request.json
async def enable_llm():
req = await request.json
TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
)
@ -287,8 +288,8 @@ def enable_llm():
@manager.route("/delete_factory", methods=["POST"]) # noqa: F821
@login_required
@validate_request("llm_factory")
def delete_factory():
req = request.json
async def delete_factory():
req = await request.json
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True)

View File

@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response, request
from flask_login import current_user, login_required
from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
@ -30,7 +30,7 @@ from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_too
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_mcp() -> Response:
async def list_mcp() -> Response:
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
@ -40,7 +40,7 @@ def list_mcp() -> Response:
else:
desc = True
req = request.get_json()
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
@ -72,8 +72,8 @@ def detail() -> Response:
@manager.route("/create", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "url", "server_type")
def create() -> Response:
req = request.get_json()
async def create() -> Response:
req = await request.get_json()
server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES:
@ -127,8 +127,8 @@ def create() -> Response:
@manager.route("/update", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id")
def update() -> Response:
req = request.get_json()
async def update() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
@ -183,8 +183,8 @@ def update() -> Response:
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def rm() -> Response:
req = request.get_json()
async def rm() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
try:
@ -201,8 +201,8 @@ def rm() -> Response:
@manager.route("/import", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcpServers")
def import_multiple() -> Response:
req = request.get_json()
async def import_multiple() -> Response:
req = await request.get_json()
servers = req.get("mcpServers", {})
if not servers:
return get_data_error_result(message="No MCP servers provided.")
@ -268,8 +268,8 @@ def import_multiple() -> Response:
@manager.route("/export", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def export_multiple() -> Response:
req = request.get_json()
async def export_multiple() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
@ -300,8 +300,8 @@ def export_multiple() -> Response:
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def list_tools() -> Response:
req = request.get_json()
async def list_tools() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")
@ -347,8 +347,8 @@ def list_tools() -> Response:
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
def test_tool() -> Response:
req = request.get_json()
async def test_tool() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@ -380,8 +380,8 @@ def test_tool() -> Response:
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tools")
def cache_tool() -> Response:
req = request.get_json()
async def cache_tool() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@ -403,8 +403,8 @@ def cache_tool() -> Response:
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type")
def test_mcp() -> Response:
req = request.get_json()
async def test_mcp() -> Response:
req = await request.get_json()
url = req.get("url", "")
if not url:

View File

@ -15,8 +15,8 @@
#
from flask import Response
from flask_login import login_required
from quart import Response
from api.apps import login_required
from api.utils.api_utils import get_json_result
from plugin import GlobalPluginManager

View File

@ -27,7 +27,7 @@ from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
from api.utils.api_utils import get_result
from flask import request, Response
from quart import request, Response
@manager.route('/agents', methods=['GET']) # noqa: F821
@ -52,8 +52,8 @@ def list_agents(tenant_id):
@manager.route("/agents", methods=["POST"]) # noqa: F821
@token_required
def create_agent(tenant_id: str):
req: dict[str, Any] = cast(dict[str, Any], request.json)
async def create_agent(tenant_id: str):
req: dict[str, Any] = cast(dict[str, Any], await request.json)
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@ -89,8 +89,8 @@ def create_agent(tenant_id: str):
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
@token_required
def update_agent(tenant_id: str, agent_id: str):
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None}
async def update_agent(tenant_id: str, agent_id: str):
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@ -135,8 +135,8 @@ def delete_agent(tenant_id: str, agent_id: str):
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required
def webhook(tenant_id: str, agent_id: str):
req = request.json
async def webhook(tenant_id: str, agent_id: str):
req = await request.json
if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',

View File

@ -14,22 +14,20 @@
# limitations under the License.
#
import logging
from flask import request
from quart import request
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json
@manager.route("/chats", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id):
req = request.json
async def create(tenant_id):
req = await request_json()
ids = [i for i in req.get("dataset_ids", []) if i]
for kb_id in ids:
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@ -145,10 +143,10 @@ def create(tenant_id):
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id):
async def update(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You do not own the chat")
req = request.json
req = await request_json()
ids = req.get("dataset_ids", [])
if "show_quotation" in req:
req["do_refer"] = req.pop("show_quotation")
@ -228,10 +226,10 @@ def update(tenant_id, chat_id):
@manager.route("/chats", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id):
async def delete_chats(tenant_id):
errors = []
success_count = 0
req = request.json
req = await request_json()
if not req:
ids = None
else:
@ -251,8 +249,8 @@ def delete(tenant_id):
errors.append(f"Assistant({id}) not found.")
continue
temp_dict = {"status": StatusEnum.INVALID.value}
DialogService.update_by_id(id, temp_dict)
success_count += 1
success_count += DialogService.update_by_id(id, temp_dict)
print(success_count, "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$", flush=True)
if errors:
if success_count > 0:

View File

@ -18,7 +18,7 @@
import logging
import os
import json
from flask import request
from quart import request
from peewee import OperationalError
from api.db.db_models import File
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
@ -54,7 +54,7 @@ from common import settings
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id):
async def create(tenant_id):
"""
Create a new dataset.
---
@ -116,16 +116,19 @@ def create(tenant_id):
# | embedding_model| embd_id |
# | chunk_method | parser_id |
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
req, err = await validate_and_parse_json_request(request, CreateDatasetReq)
if err is not None:
return get_error_argument_result(err)
req = KnowledgebaseService.create_with_name(
e, req = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = tenant_id,
parser_id = req.pop("parser_id", None),
**req
)
if not e:
return req
# Insert embedding model(embd id)
ok, t = TenantService.get_by_id(tenant_id)
if not ok:
@ -152,7 +155,7 @@ def create(tenant_id):
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id):
async def delete(tenant_id):
"""
Delete datasets.
---
@ -190,7 +193,7 @@ def delete(tenant_id):
schema:
type: object
"""
req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
req, err = await validate_and_parse_json_request(request, DeleteDatasetReq)
if err is not None:
return get_error_argument_result(err)
@ -250,7 +253,7 @@ def delete(tenant_id):
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, dataset_id):
async def update(tenant_id, dataset_id):
"""
Update a dataset.
---
@ -316,7 +319,7 @@ def update(tenant_id, dataset_id):
# | embedding_model| embd_id |
# | chunk_method | parser_id |
extras = {"dataset_id": dataset_id}
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
if err is not None:
return get_error_argument_result(err)

View File

@ -15,7 +15,7 @@
#
import logging
from flask import request, jsonify
from quart import request, jsonify
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -29,7 +29,7 @@ from common import settings
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
@apikey_required
@validate_request("knowledge_id", "query")
def retrieval(tenant_id):
async def retrieval(tenant_id):
"""
Dify-compatible retrieval API
---
@ -113,7 +113,7 @@ def retrieval(tenant_id):
404:
description: Knowledge base or document not found
"""
req = request.json
req = await request.json
question = req["query"]
kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False)

View File

@ -20,7 +20,7 @@ import re
from io import BytesIO
import xxhash
from flask import request, send_file
from quart import request, send_file
from peewee import OperationalError
from pydantic import BaseModel, Field, validator
@ -35,7 +35,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -69,7 +70,7 @@ class Chunk(BaseModel):
@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
@token_required
def upload(dataset_id, tenant_id):
async def upload(dataset_id, tenant_id):
"""
Upload documents to a dataset.
---
@ -130,9 +131,11 @@ def upload(dataset_id, tenant_id):
type: string
description: Processing status.
"""
if "file" not in request.files:
form = await request.form
files = await request.files
if "file" not in files:
return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_result(message="No file selected!", code=RetCode.ARGUMENT_ERROR)
@ -155,7 +158,7 @@ def upload(dataset_id, tenant_id):
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=request.form.get("parent_path"))
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=form.get("parent_path"))
if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
# rename key's name
@ -179,7 +182,7 @@ def upload(dataset_id, tenant_id):
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["PUT"]) # noqa: F821
@token_required
def update_doc(tenant_id, dataset_id, document_id):
async def update_doc(tenant_id, dataset_id, document_id):
"""
Update a document within a dataset.
---
@ -228,7 +231,7 @@ def update_doc(tenant_id, dataset_id, document_id):
schema:
type: object
"""
req = request.json
req = await request_json()
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(message="You don't own the dataset.")
e, kb = KnowledgebaseService.get_by_id(dataset_id)
@ -359,7 +362,7 @@ def update_doc(tenant_id, dataset_id, document_id):
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821
@token_required
def download(tenant_id, dataset_id, document_id):
async def download(tenant_id, dataset_id, document_id):
"""
Download a document from a dataset.
---
@ -409,10 +412,10 @@ def download(tenant_id, dataset_id, document_id):
return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
file = BytesIO(file_stream)
# Use send_file with a proper filename and MIME type
return send_file(
return await send_file(
file,
as_attachment=True,
download_name=doc[0].name,
attachment_filename=doc[0].name,
mimetype="application/octet-stream", # Set a default MIME type
)
@ -589,7 +592,7 @@ def list_docs(dataset_id, tenant_id):
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, dataset_id):
async def delete(tenant_id, dataset_id):
"""
Delete documents from a dataset.
---
@ -628,7 +631,7 @@ def delete(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
req = request.json
req = await request_json()
if not req:
doc_ids = None
else:
@ -699,7 +702,7 @@ def delete(tenant_id, dataset_id):
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
@token_required
def parse(tenant_id, dataset_id):
async def parse(tenant_id, dataset_id):
"""
Start parsing documents into chunks.
---
@ -738,7 +741,7 @@ def parse(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = request.json
req = await request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
doc_list = req.get("document_ids")
@ -782,7 +785,7 @@ def parse(tenant_id, dataset_id):
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
@token_required
def stop_parsing(tenant_id, dataset_id):
async def stop_parsing(tenant_id, dataset_id):
"""
Stop parsing documents into chunks.
---
@ -821,7 +824,7 @@ def stop_parsing(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = request.json
req = await request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
@ -1023,7 +1026,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
"/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["POST"]
)
@token_required
def add_chunk(tenant_id, dataset_id, document_id):
async def add_chunk(tenant_id, dataset_id, document_id):
"""
Add a chunk to a document.
---
@ -1093,7 +1096,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
req = request.json
req = await request_json()
if not str(req.get("content", "")).strip():
return get_error_data_result(message="`content` is required")
if "important_keywords" in req:
@ -1152,7 +1155,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
"datasets/<dataset_id>/documents/<document_id>/chunks", methods=["DELETE"]
)
@token_required
def rm_chunk(tenant_id, dataset_id, document_id):
async def rm_chunk(tenant_id, dataset_id, document_id):
"""
Remove chunks from a document.
---
@ -1199,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
docs = DocumentService.get_by_ids([document_id])
if not docs:
raise LookupError(f"Can't find the document with ID {document_id}!")
req = request.json
req = await request_json()
condition = {"doc_id": document_id}
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@ -1223,7 +1226,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
"/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id>", methods=["PUT"]
)
@token_required
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
"""
Update a chunk within a document.
---
@ -1285,7 +1288,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
req = request.json
req = await request_json()
if "content" in req:
content = req["content"]
else:
@ -1327,7 +1330,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
@token_required
def retrieval_test(tenant_id):
async def retrieval_test(tenant_id):
"""
Retrieve chunks based on a query.
---
@ -1408,7 +1411,7 @@ def retrieval_test(tenant_id):
format: float
description: Similarity score.
"""
req = request.json
req = await request_json()
if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.")
kb_ids = req["dataset_ids"]

View File

@ -17,9 +17,7 @@
import pathlib
import re
import flask
from flask import request
from quart import request, make_response
from pathlib import Path
from api.db.services.document_service import DocumentService
@ -37,7 +35,7 @@ from common import settings
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
def upload(tenant_id):
async def upload(tenant_id):
"""
Upload a file to the system.
---
@ -79,15 +77,17 @@ def upload(tenant_id):
type: string
description: File type (e.g., document, folder)
"""
pf_id = request.form.get("parent_id")
form = await request.form
files = await request.files
pf_id = form.get("parent_id")
if not pf_id:
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
if 'file' not in request.files:
if 'file' not in files:
return get_json_result(data=False, message='No file part!', code=400)
file_objs = request.files.getlist('file')
file_objs = files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
@ -151,7 +151,7 @@ def upload(tenant_id):
@manager.route('/file/create', methods=['POST']) # noqa: F821
@token_required
def create(tenant_id):
async def create(tenant_id):
"""
Create a new file or folder.
---
@ -193,9 +193,9 @@ def create(tenant_id):
type:
type: string
"""
req = request.json
pf_id = request.json.get("parent_id")
input_file_type = request.json.get("type")
req = await request.json
pf_id = await request.json.get("parent_id")
input_file_type = await request.json.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
@ -450,7 +450,7 @@ def get_all_parent_folders(tenant_id):
@manager.route('/file/rm', methods=['POST']) # noqa: F821
@token_required
def rm(tenant_id):
async def rm(tenant_id):
"""
Delete one or multiple files/folders.
---
@ -481,7 +481,7 @@ def rm(tenant_id):
type: boolean
example: true
"""
req = request.json
req = await request.json
file_ids = req["file_ids"]
try:
for file_id in file_ids:
@ -524,7 +524,7 @@ def rm(tenant_id):
@manager.route('/file/rename', methods=['POST']) # noqa: F821
@token_required
def rename(tenant_id):
async def rename(tenant_id):
"""
Rename a file.
---
@ -556,7 +556,7 @@ def rename(tenant_id):
type: boolean
example: true
"""
req = request.json
req = await request.json
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@ -585,7 +585,7 @@ def rename(tenant_id):
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
@token_required
def get(tenant_id, file_id):
async def get(tenant_id, file_id):
"""
Download a file.
---
@ -619,7 +619,7 @@ def get(tenant_id, file_id):
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = settings.STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name)
if ext:
if file.type == FileType.VISUAL.value:
@ -633,7 +633,7 @@ def get(tenant_id, file_id):
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
def move(tenant_id):
async def move(tenant_id):
"""
Move one or multiple files to another folder.
---
@ -667,7 +667,7 @@ def move(tenant_id):
type: boolean
example: true
"""
req = request.json
req = await request.json
try:
file_ids = req["src_file_ids"]
parent_id = req["dest_file_id"]
@ -693,8 +693,8 @@ def move(tenant_id):
@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
def convert(tenant_id):
req = request.json
async def convert(tenant_id):
req = await request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []

View File

@ -18,7 +18,7 @@ import re
import time
import tiktoken
from flask import Response, jsonify, request
from quart import Response, jsonify, request
from agent.canvas import Canvas
from api.db.db_models import APIToken
@ -44,8 +44,8 @@ from common import settings
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id, chat_id):
req = request.json
async def create(tenant_id, chat_id):
req = await request.json
req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia:
@ -97,8 +97,8 @@ def create_agent_session(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id, session_id):
req = request.json
async def update(tenant_id, chat_id, session_id):
req = await request.json
req["dialog_id"] = chat_id
conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
@ -119,8 +119,8 @@ def update(tenant_id, chat_id, session_id):
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def chat_completion(tenant_id, chat_id):
req = request.json
async def chat_completion(tenant_id, chat_id):
req = await request.json
if not req:
req = {"question": ""}
if not req.get("session_id"):
@ -149,7 +149,7 @@ def chat_completion(tenant_id, chat_id):
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def chat_completion_openai_like(tenant_id, chat_id):
async def chat_completion_openai_like(tenant_id, chat_id):
"""
OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
@ -206,7 +206,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
if reference:
print(completion.choices[0].message.reference)
"""
req = request.get_json()
req = await request.get_json()
need_reference = bool(req.get("reference", False))
@ -383,8 +383,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def agents_completion_openai_compatibility(tenant_id, agent_id):
req = request.json
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await request.json
tiktokenenc = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
@ -443,8 +443,8 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def agent_completions(tenant_id, agent_id):
req = request.json
async def agent_completions(tenant_id, agent_id):
req = await request.json
if req.get("stream", True):
@ -610,13 +610,13 @@ def list_agent_session(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, chat_id):
async def delete(tenant_id, chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You don't own the chat")
errors = []
success_count = 0
req = request.json
req = await request.json
convs = ConversationService.query(dialog_id=chat_id)
if not req:
ids = None
@ -661,10 +661,10 @@ def delete(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete_agent_session(tenant_id, agent_id):
async def delete_agent_session(tenant_id, agent_id):
errors = []
success_count = 0
req = request.json
req = await request.json
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
@ -716,8 +716,8 @@ def delete_agent_session(tenant_id, agent_id):
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
def ask_about(tenant_id):
req = request.json
async def ask_about(tenant_id):
req = await request.json
if not req.get("question"):
return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"):
@ -755,8 +755,8 @@ def ask_about(tenant_id):
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
def related_questions(tenant_id):
req = request.json
async def related_questions(tenant_id):
req = await request.json
if not req.get("question"):
return get_error_data_result("`question` is required.")
question = req["question"]
@ -806,8 +806,8 @@ Related search terms:
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
def chatbot_completions(dialog_id):
req = request.json
async def chatbot_completions(dialog_id):
req = await request.json
token = request.headers.get("Authorization").split()
if len(token) != 2:
@ -856,8 +856,8 @@ def chatbots_inputs(dialog_id):
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
def agent_bot_completions(agent_id):
req = request.json
async def agent_bot_completions(agent_id):
req = await request.json
token = request.headers.get("Authorization").split()
if len(token) != 2:
@ -901,7 +901,7 @@ def begin_inputs(agent_id):
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
@validate_request("question", "kb_ids")
def ask_about_embedded():
async def ask_about_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -910,7 +910,7 @@ def ask_about_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = request.json
req = await request.json
uid = objs[0].tenant_id
search_id = req.get("search_id", "")
@ -940,7 +940,7 @@ def ask_about_embedded():
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
@validate_request("kb_id", "question")
def retrieval_test_embedded():
async def retrieval_test_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -949,7 +949,7 @@ def retrieval_test_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = request.json
req = await request.json
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
@ -1039,7 +1039,7 @@ def retrieval_test_embedded():
@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
@validate_request("question")
def related_questions_embedded():
async def related_questions_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -1048,7 +1048,7 @@ def related_questions_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
req = request.json
req = await request.json
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
@ -1115,7 +1115,7 @@ def detail_share_embedded():
@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
@validate_request("question", "kb_ids")
def mindmap():
async def mindmap():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@ -1125,7 +1125,7 @@ def mindmap():
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
req = request.json
req = await request.json
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}

View File

@ -14,8 +14,8 @@
# limitations under the License.
#
from flask import request
from flask_login import current_user, login_required
from quart import request
from api.apps import current_user, login_required
from api.constants import DATASET_NAME_LIMIT
from api.db.db_models import DB
@ -30,8 +30,8 @@ from api.utils.api_utils import get_data_error_result, get_json_result, not_allo
@manager.route("/create", methods=["post"]) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.get_json()
async def create():
req = await request.get_json()
search_name = req["name"]
description = req.get("description", "")
if not isinstance(search_name, str):
@ -65,8 +65,8 @@ def create():
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
req = request.get_json()
async def update():
req = await request.get_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "":
@ -140,7 +140,7 @@ def detail():
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_search_app():
async def list_search_app():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
@ -150,7 +150,7 @@ def list_search_app():
else:
desc = True
req = request.get_json()
req = await request.get_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@ -173,8 +173,8 @@ def list_search_app():
@manager.route("/rm", methods=["post"]) # noqa: F821
@login_required
@validate_request("search_id")
def rm():
req = request.get_json()
async def rm():
req = await request.get_json()
search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

View File

@ -17,7 +17,7 @@ import logging
from datetime import datetime
import json
from flask_login import login_required, current_user
from api.apps import login_required, current_user
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
@ -34,7 +34,7 @@ from common.time_utils import current_timestamp, datetime_format
from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from quart import jsonify
from api.utils.health_utils import run_health_checks
from common import settings

View File

@ -14,10 +14,7 @@
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from api.apps import smtp_mail_server
from quart import request
from api.db import UserTenantRole
from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService
@ -28,6 +25,7 @@ from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.web_utils import send_invite_email
from common import settings
from api.apps import smtp_mail_server, login_required, current_user
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
@ -51,14 +49,14 @@ def user_list(tenant_id):
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
@login_required
@validate_request("email")
def create(tenant_id):
async def create(tenant_id):
if current_user.id != tenant_id:
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR)
req = request.json
req = await request.json
invite_user_email = req["email"]
invite_users = UserService.query(email=invite_user_email)
if not invite_users:

View File

@ -22,8 +22,7 @@ import secrets
import time
from datetime import datetime
from flask import redirect, request, session, make_response
from flask_login import current_user, login_required, login_user, logout_user
from quart import redirect, request, session, make_response
from werkzeug.security import check_password_hash, generate_password_hash
from api.apps.auth import get_auth_client
@ -45,7 +44,7 @@ from api.utils.api_utils import (
)
from api.utils.crypt import decrypt
from rag.utils.redis_conn import REDIS_CONN
from api.apps import smtp_mail_server
from api.apps import smtp_mail_server, login_required, current_user, login_user, logout_user
from api.utils.web_utils import (
send_email_html,
OTP_LENGTH,
@ -61,7 +60,7 @@ from common import settings
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
def login():
async def login():
"""
User login endpoint.
---
@ -91,10 +90,11 @@ def login():
schema:
type: object
"""
if not request.json:
json_body = await request.json
if not json_body:
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
email = request.json.get("email", "")
email = json_body.get("email", "")
users = UserService.query(email=email)
if not users:
return get_json_result(
@ -103,7 +103,7 @@ def login():
message=f"Email: {email} is not registered!",
)
password = request.json.get("password")
password = json_body.get("password")
try:
password = decrypt(password)
except BaseException:
@ -125,7 +125,8 @@ def login():
user.update_date = (datetime_format(datetime.now()),)
user.save()
msg = "Welcome back!"
return construct_response(data=response_data, auth=user.get_id(), message=msg)
return await construct_response(data=response_data, auth=user.get_id(), message=msg)
else:
return get_json_result(
data=False,
@ -501,7 +502,7 @@ def log_out():
@manager.route("/setting", methods=["POST"]) # noqa: F821
@login_required
def setting_user():
async def setting_user():
"""
Update user settings.
---
@ -530,7 +531,7 @@ def setting_user():
type: object
"""
update_dict = {}
request_data = request.json
request_data = await request.json
if request_data.get("password"):
new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
@ -660,7 +661,7 @@ def user_register(user_id, user):
@manager.route("/register", methods=["POST"]) # noqa: F821
@validate_request("nickname", "email", "password")
def user_add():
async def user_add():
"""
Register a new user.
---
@ -697,7 +698,7 @@ def user_add():
code=RetCode.OPERATING_ERROR,
)
req = request.json
req = await request.json
email_address = req["email"]
# Validate the email address
@ -737,7 +738,7 @@ def user_add():
raise Exception(f"Same email: {email_address} exists!")
user = users[0]
login_user(user)
return construct_response(
return await construct_response(
data=user.to_json(),
auth=user.get_id(),
message=f"{nickname}, welcome aboard!",
@ -793,7 +794,7 @@ def tenant_info():
@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
def set_tenant_info():
async def set_tenant_info():
"""
Update tenant information.
---
@ -830,7 +831,7 @@ def set_tenant_info():
schema:
type: object
"""
req = request.json
req = await request.json
try:
tid = req.pop("tenant_id")
TenantService.update_by_id(tid, req)
@ -840,7 +841,7 @@ def set_tenant_info():
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
def forget_get_captcha():
async def forget_get_captcha():
"""
GET /forget/captcha?email=<email>
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS.
@ -862,19 +863,19 @@ def forget_get_captcha():
from captcha.image import ImageCaptcha
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
img_bytes = image.generate(captcha_text).read()
response = make_response(img_bytes)
response = await make_response(img_bytes)
response.headers.set("Content-Type", "image/JPEG")
return response
@manager.route("/forget/otp", methods=["POST"]) # noqa: F821
def forget_send_otp():
async def forget_send_otp():
"""
POST /forget/otp
- Verify the image captcha stored at captcha:{email} (case-insensitive).
- On success, generate an email OTP (AZ with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
"""
req = request.get_json()
req = await request.get_json()
email = req.get("email") or ""
captcha = (req.get("captcha") or "").strip()
@ -935,12 +936,12 @@ def forget_send_otp():
@manager.route("/forget", methods=["POST"]) # noqa: F821
def forget():
async def forget():
"""
POST: Verify email + OTP and reset password, then log the user in.
Request JSON: { email, otp, new_password, confirm_new_password }
"""
req = request.get_json()
req = await request.get_json()
email = req.get("email") or ""
otp = (req.get("otp") or "").strip()
new_pwd = req.get("new_password")

View File

@ -25,7 +25,7 @@ from datetime import datetime, timezone
from enum import Enum
from functools import wraps
from flask_login import UserMixin
from quart_auth import AuthUser
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
@ -595,7 +595,7 @@ def fill_db_model_object(model_object, human_model_dict):
return model_object
class User(DataBaseModel, UserMixin):
class User(DataBaseModel, AuthUser):
id = CharField(max_length=32, primary_key=True)
access_token = CharField(max_length=255, null=True, index=True)
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)

View File

@ -24,7 +24,6 @@ from api.db import InputType
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from common.misc_utils import get_uuid
from common.constants import TaskStatus
from common.time_utils import current_timestamp, timestamp_to_date
@ -68,6 +67,7 @@ class ConnectorService(CommonService):
@classmethod
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
from api.db.services.file_service import FileService
e, conn = cls.get_by_id(connector_id)
if not e:
return None
@ -191,6 +191,7 @@ class SyncLogsService(CommonService):
@classmethod
def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
from api.db.services.file_service import FileService
if not docs:
return None

View File

@ -41,6 +41,7 @@ from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common import settings
class DocumentService(CommonService):
model = Document

View File

@ -18,7 +18,6 @@ import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from flask_login import current_user
from peewee import fn
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType
@ -34,6 +33,7 @@ from api.db.services.task_service import TaskService
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
from rag.llm.cv_model import GptV4
from common import settings
from api.apps import current_user
class FileService(CommonService):

View File

@ -24,9 +24,10 @@ from common.time_utils import current_timestamp, datetime_format
from api.db.services import duplicate_name
from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid
from common.constants import StatusEnum, RetCode
from common.constants import StatusEnum
from api.constants import DATASET_NAME_LIMIT
from api.utils.api_utils import get_parser_config
from api.utils.api_utils import get_parser_config, get_data_error_result
class KnowledgebaseService(CommonService):
"""Service class for managing knowledge base operations.
@ -391,12 +392,12 @@ class KnowledgebaseService(CommonService):
"""
# Validate name
if not isinstance(name, str):
return {"code": RetCode.DATA_ERROR, "message": "Dataset name must be string."}
return False, get_data_error_result(message="Dataset name must be string.")
dataset_name = name.strip()
if len(dataset_name) == 0:
return {"code": RetCode.DATA_ERROR, "message": "Dataset name can't be empty."}
if dataset_name == "":
return False, get_data_error_result(message="Dataset name can't be empty.")
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
return {"code": RetCode.DATA_ERROR, "message": f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}"}
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
# Deduplicate name within tenant
dataset_name = duplicate_name(
@ -409,7 +410,7 @@ class KnowledgebaseService(CommonService):
# Verify tenant exists
ok, _t = TenantService.get_by_id(tenant_id)
if not ok:
return {"code": RetCode.DATA_ERROR, "message": "Tenant does not exist."}
return False, get_data_error_result(message="Tenant not found.")
# Build payload
kb_id = get_uuid()
@ -425,7 +426,7 @@ class KnowledgebaseService(CommonService):
# Update parser_config (always override with validated default/merged config)
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
return payload
return True, payload
@classmethod

View File

@ -31,7 +31,6 @@ import traceback
import threading
import uuid
from werkzeug.serving import run_simple
from api.apps import app, smtp_mail_server
from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService
@ -153,14 +152,7 @@ if __name__ == '__main__':
# start http server
try:
logging.info("RAGFlow HTTP server start...")
run_simple(
hostname=settings.HOST_IP,
port=settings.HOST_PORT,
application=app,
threaded=True,
use_reloader=RuntimeConfig.DEBUG,
use_debugger=RuntimeConfig.DEBUG,
)
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
except Exception:
traceback.print_exc()
stop_event.set()

View File

@ -15,6 +15,7 @@
#
import functools
import inspect
import json
import logging
import os
@ -24,14 +25,12 @@ from functools import wraps
import requests
import trio
from flask import (
from quart import (
Response,
jsonify,
request
)
from flask_login import current_user
from flask import (
request as flask_request,
)
from peewee import OperationalError
from common.constants import ActiveEnum
@ -46,6 +45,12 @@ from common import settings
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
async def request_json():
try:
return await request.json
except Exception:
return {}
def serialize_for_json(obj):
"""
Recursively serialize objects to make them JSON serializable.
@ -105,31 +110,37 @@ def server_error_response(e):
def validate_request(*args, **kwargs):
def process_args(input_arguments):
no_arguments = []
error_arguments = []
for arg in args:
if arg not in input_arguments:
no_arguments.append(arg)
for k, v in kwargs.items():
config_value = input_arguments.get(k, None)
if config_value is None:
no_arguments.append(k)
elif isinstance(v, (tuple, list)):
if config_value not in v:
error_arguments.append((k, set(v)))
elif config_value != v:
error_arguments.append((k, v))
if no_arguments or error_arguments:
error_string = ""
if no_arguments:
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
if error_arguments:
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return error_string
def wrapper(func):
@wraps(func)
def decorated_function(*_args, **_kwargs):
input_arguments = flask_request.json or flask_request.form.to_dict()
no_arguments = []
error_arguments = []
for arg in args:
if arg not in input_arguments:
no_arguments.append(arg)
for k, v in kwargs.items():
config_value = input_arguments.get(k, None)
if config_value is None:
no_arguments.append(k)
elif isinstance(v, (tuple, list)):
if config_value not in v:
error_arguments.append((k, set(v)))
elif config_value != v:
error_arguments.append((k, v))
if no_arguments or error_arguments:
error_string = ""
if no_arguments:
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
if error_arguments:
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string)
async def decorated_function(*_args, **_kwargs):
errs = process_args(await request.json or (await request.form).to_dict())
if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func):
return await func(*_args, **_kwargs)
return func(*_args, **_kwargs)
return decorated_function
@ -138,30 +149,34 @@ def validate_request(*args, **kwargs):
def not_allowed_parameters(*params):
def decorator(f):
def wrapper(*args, **kwargs):
input_arguments = flask_request.json or flask_request.form.to_dict()
def decorator(func):
async def wrapper(*args, **kwargs):
input_arguments = await request.json or (await request.form).to_dict()
for param in params:
if param in input_arguments:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
return f(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
return func(*args, **kwargs)
return wrapper
return decorator
def active_required(f):
@wraps(f)
def wrapper(*args, **kwargs):
def active_required(func):
@wraps(func)
async def wrapper(*args, **kwargs):
from api.db.services import UserService
from api.apps import current_user
user_id = current_user.id
usr = UserService.filter_by_id(user_id)
# check is_active
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
return get_json_result(code=RetCode.FORBIDDEN, message="User isn't active, please activate first.")
return f(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
return func(*args, **kwargs)
return wrapper
@ -173,12 +188,15 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non
def apikey_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
token = flask_request.headers.get("Authorization").split()[1]
async def decorated_function(*args, **kwargs):
token = request.headers.get("Authorization").split()[1]
objs = APIToken.query(token=token)
if not objs:
return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN)
kwargs["tenant_id"] = objs[0].tenant_id
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
return func(*args, **kwargs)
return decorated_function
@ -199,23 +217,38 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
def token_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
def get_tenant_id(**kwargs):
if os.environ.get("DISABLE_SDK"):
return get_json_result(data=False, message="`Authorization` can't be empty")
authorization_str = flask_request.headers.get("Authorization")
return False, get_json_result(data=False, message="`Authorization` can't be empty")
authorization_str = request.headers.get("Authorization")
if not authorization_str:
return get_json_result(data=False, message="`Authorization` can't be empty")
return False, get_json_result(data=False, message="`Authorization` can't be empty")
authorization_list = authorization_str.split()
if len(authorization_list) < 2:
return get_json_result(data=False, message="Please check your authorization format.")
return False, get_json_result(data=False, message="Please check your authorization format.")
token = authorization_list[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
kwargs["tenant_id"] = objs[0].tenant_id
return True, kwargs
@wraps(func)
def decorated_function(*args, **kwargs):
e, kwargs = get_tenant_id(**kwargs)
if not e:
return kwargs
return func(*args, **kwargs)
@wraps(func)
async def adecorated_function(*args, **kwargs):
e, kwargs = get_tenant_id(**kwargs)
if not e:
return kwargs
return await func(*args, **kwargs)
if inspect.iscoroutinefunction(func):
return adecorated_function
return decorated_function

View File

@ -18,7 +18,7 @@ import base64
import click
import re
from flask import Flask
from quart import Quart
from werkzeug.security import generate_password_hash
from api.db.services import UserService
@ -73,6 +73,7 @@ def reset_email(email, new_email, email_confirm):
UserService.update_user(user[0].id,user_dict)
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
def register_commands(app: Flask):
def register_commands(app: Quart):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)

View File

@ -17,7 +17,7 @@ from collections import Counter
from typing import Annotated, Any, Literal
from uuid import UUID
from flask import Request
from quart import Request
from pydantic import (
BaseModel,
ConfigDict,
@ -32,7 +32,7 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses JSON requests through a multi-stage validation pipeline.
@ -81,7 +81,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
from the final output after validation
"""
try:
payload = request.get_json() or {}
payload = await request.get_json() or {}
except UnsupportedMediaType:
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
except BadRequest:

View File

@ -23,7 +23,7 @@ from urllib.parse import urlparse
from api.apps import smtp_mail_server
from flask_mail import Message
from flask import render_template_string
from quart import render_template_string
from api.utils.email_templates import EMAIL_TEMPLATES
from selenium import webdriver
from selenium.common.exceptions import TimeoutException