mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
Compare commits
9 Commits
0db00f70b2
...
50bc53a1f5
| Author | SHA1 | Date | |
|---|---|---|---|
| 50bc53a1f5 | |||
| 8cd4882596 | |||
| 35e5fade93 | |||
| 4942a23290 | |||
| d1716d865a | |||
| c2b7c305fa | |||
| 341e5904c8 | |||
| ded9bf80c5 | |||
| fea157ba08 |
@ -20,8 +20,10 @@ import logging
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
from werkzeug.serving import run_simple
|
||||
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager
|
||||
from werkzeug.serving import run_simple
|
||||
from routes import admin_bp
|
||||
from common.log_utils import init_root_logger
|
||||
from common.constants import SERVICE_CONF
|
||||
@ -30,7 +32,6 @@ from common import settings
|
||||
from config import load_configurations, SERVICE_CONFIGS
|
||||
from auth import init_default_admin, setup_auth
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from common.versions import get_ragflow_version
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
@ -19,7 +19,8 @@ import logging
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
from flask import request, jsonify
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_login import current_user, login_user
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
|
||||
@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import secrets
|
||||
|
||||
from flask import Blueprint, request
|
||||
from flask_login import current_user, logout_user, login_required
|
||||
from flask_login import current_user, login_required, logout_user
|
||||
|
||||
from auth import login_verify, login_admin, check_admin_auth
|
||||
from responses import success_response, error_response
|
||||
@ -30,12 +30,12 @@ admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||
|
||||
|
||||
@admin_bp.route('/login', methods=['POST'])
|
||||
def login():
|
||||
if not request.json:
|
||||
async def login():
|
||||
if not await request.json:
|
||||
return error_response('Authorize admin failed.' ,400)
|
||||
try:
|
||||
email = request.json.get("email", "")
|
||||
password = request.json.get("password", "")
|
||||
email = await request.json.get("email", "")
|
||||
password = await request.json.get("password", "")
|
||||
return login_admin(email, password)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
@ -76,9 +76,9 @@ def list_users():
|
||||
@admin_bp.route('/users', methods=['POST'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def create_user():
|
||||
async def create_user():
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'username' not in data or 'password' not in data:
|
||||
return error_response("Username and password are required", 400)
|
||||
|
||||
@ -120,9 +120,9 @@ def delete_user(username):
|
||||
@admin_bp.route('/users/<username>/password', methods=['PUT'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def change_password(username):
|
||||
async def change_password(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'new_password' not in data:
|
||||
return error_response("New password is required", 400)
|
||||
|
||||
@ -139,9 +139,9 @@ def change_password(username):
|
||||
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def alter_user_activate_status(username):
|
||||
async def alter_user_activate_status(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'activate_status' not in data:
|
||||
return error_response("Activation status is required", 400)
|
||||
activate_status = data['activate_status']
|
||||
@ -253,9 +253,9 @@ def restart_service(service_id):
|
||||
@admin_bp.route('/roles', methods=['POST'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def create_role():
|
||||
async def create_role():
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'role_name' not in data:
|
||||
return error_response("Role name is required", 400)
|
||||
role_name: str = data['role_name']
|
||||
@ -269,9 +269,9 @@ def create_role():
|
||||
@admin_bp.route('/roles/<role_name>', methods=['PUT'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def update_role(role_name: str):
|
||||
async def update_role(role_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'description' not in data:
|
||||
return error_response("Role description is required", 400)
|
||||
description: str = data['description']
|
||||
@ -317,9 +317,9 @@ def get_role_permission(role_name: str):
|
||||
@admin_bp.route('/roles/<role_name>/permission', methods=['POST'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def grant_role_permission(role_name: str):
|
||||
async def grant_role_permission(role_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'actions' not in data or 'resource' not in data:
|
||||
return error_response("Permission is required", 400)
|
||||
actions: list = data['actions']
|
||||
@ -333,9 +333,9 @@ def grant_role_permission(role_name: str):
|
||||
@admin_bp.route('/roles/<role_name>/permission', methods=['DELETE'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def revoke_role_permission(role_name: str):
|
||||
async def revoke_role_permission(role_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'actions' not in data or 'resource' not in data:
|
||||
return error_response("Permission is required", 400)
|
||||
actions: list = data['actions']
|
||||
@ -349,9 +349,9 @@ def revoke_role_permission(role_name: str):
|
||||
@admin_bp.route('/users/<user_name>/role', methods=['PUT'])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def update_user_role(user_name: str):
|
||||
async def update_user_role(user_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
if not data or 'role_name' not in data:
|
||||
return error_response("Role name is required", 400)
|
||||
role_name: str = data['role_name']
|
||||
|
||||
@ -25,7 +25,6 @@ from typing import Any, Union, Tuple
|
||||
|
||||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
@ -217,6 +216,38 @@ class Graph:
|
||||
else:
|
||||
cur = getattr(cur, key, None)
|
||||
return cur
|
||||
|
||||
def set_variable_value(self, exp: str,value):
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
self.globals[exp] = value
|
||||
return
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
parts = var_nm.split(".", 1)
|
||||
root_key = parts[0]
|
||||
rest = parts[1] if len(parts) > 1 else ""
|
||||
if not rest:
|
||||
cpn["obj"].set_output(root_key, value)
|
||||
return
|
||||
root_val = cpn["obj"].output(root_key)
|
||||
if not root_val:
|
||||
root_val = {}
|
||||
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value))
|
||||
|
||||
def set_variable_param_value(self, obj: Any, path: str, value) -> Any:
|
||||
cur = obj
|
||||
keys = path.split('.')
|
||||
if not path:
|
||||
return value
|
||||
for key in keys:
|
||||
if key not in cur or not isinstance(cur[key], dict):
|
||||
cur[key] = {}
|
||||
cur = cur[key]
|
||||
cur[keys[-1]] = value
|
||||
return obj
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
return has_canceled(self.task_id)
|
||||
@ -284,7 +315,7 @@ class Canvas(Graph):
|
||||
else:
|
||||
self.globals[k] = None
|
||||
|
||||
def run(self, **kwargs):
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.message_id = get_uuid()
|
||||
created_at = int(time.time())
|
||||
@ -549,6 +580,7 @@ class Canvas(Graph):
|
||||
return self.components[cpnnm]["obj"].get_input_elements()
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
from api.db.services.file_service import FileService
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
|
||||
@ -163,12 +163,7 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
output_structure=None
|
||||
try:
|
||||
output_structure=self._param.outputs['structured']
|
||||
except Exception:
|
||||
pass
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||
return
|
||||
|
||||
|
||||
@ -1,3 +1,18 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import ast
|
||||
import os
|
||||
|
||||
@ -32,6 +32,7 @@ class IterationParam(ComponentParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.items_ref = ""
|
||||
self.veriable={}
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
|
||||
@ -249,7 +249,7 @@ class LLM(ComponentBase):
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self._stream_output, prompt, msg))
|
||||
return
|
||||
|
||||
|
||||
@ -17,7 +17,6 @@ import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import pypandoc
|
||||
import logging
|
||||
import tempfile
|
||||
from functools import partial
|
||||
@ -30,6 +29,7 @@ from common.connection_utils import timeout
|
||||
from common.misc_utils import get_uuid
|
||||
from common import settings
|
||||
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Message component parameters.
|
||||
@ -176,6 +176,10 @@ class Message(ComponentBase):
|
||||
return ""
|
||||
|
||||
def _convert_content(self, content):
|
||||
if not self._param.output_format:
|
||||
return
|
||||
|
||||
import pypandoc
|
||||
doc_id = get_uuid()
|
||||
|
||||
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
|
||||
|
||||
187
agent/component/variable_assigner.py
Normal file
187
agent/component/variable_assigner.py
Normal file
@ -0,0 +1,187 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import os
|
||||
import numbers
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
class VariableAssignerParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Variable Assigner component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.variables=[]
|
||||
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"items": {
|
||||
"type": "json",
|
||||
"name": "Items"
|
||||
}
|
||||
}
|
||||
|
||||
class VariableAssigner(ComponentBase,ABC):
|
||||
component_name = "VariableAssigner"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not isinstance(self._param.variables,list):
|
||||
return
|
||||
else:
|
||||
for item in self._param.variables:
|
||||
variable=item["variable"]
|
||||
operator=item["operator"]
|
||||
parameter=item["parameter"]
|
||||
variable_value=self._canvas.get_variable_value(variable)
|
||||
new_variable=self._operate(variable_value,operator,parameter)
|
||||
self._canvas.set_variable_value(variable,new_variable)
|
||||
|
||||
def _operate(self,variable,operator,parameter):
|
||||
if operator == "overwrite":
|
||||
return self._overwrite(parameter)
|
||||
elif operator == "clear":
|
||||
return self._clear(variable)
|
||||
elif operator == "set":
|
||||
return self._set(variable,parameter)
|
||||
elif operator == "append":
|
||||
return self._append(variable,parameter)
|
||||
elif operator == "extend":
|
||||
return self._extend(variable,parameter)
|
||||
elif operator == "remove_first":
|
||||
return self._remove_first(variable)
|
||||
elif operator == "remove_last":
|
||||
return self._remove_last(variable)
|
||||
elif operator == "+=":
|
||||
return self._add(variable,parameter)
|
||||
elif operator == "-=":
|
||||
return self._subtract(variable,parameter)
|
||||
elif operator == "*=":
|
||||
return self._multiply(variable,parameter)
|
||||
elif operator == "/=":
|
||||
return self._divide(variable,parameter)
|
||||
else:
|
||||
return
|
||||
|
||||
def _overwrite(self,parameter):
|
||||
return self._canvas.get_variable_value(parameter)
|
||||
|
||||
def _clear(self,variable):
|
||||
if isinstance(variable,list):
|
||||
return []
|
||||
elif isinstance(variable,str):
|
||||
return ""
|
||||
elif isinstance(variable,dict):
|
||||
return {}
|
||||
elif isinstance(variable,int):
|
||||
return 0
|
||||
elif isinstance(variable,float):
|
||||
return 0.0
|
||||
elif isinstance(variable,bool):
|
||||
return False
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set(self,variable,parameter):
|
||||
if variable is None:
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,str):
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,bool):
|
||||
return parameter
|
||||
elif isinstance(variable,int):
|
||||
return parameter
|
||||
elif isinstance(variable,float):
|
||||
return parameter
|
||||
else:
|
||||
return parameter
|
||||
|
||||
def _append(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif len(variable)!=0 and not isinstance(parameter,type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
return variable+parameter
|
||||
|
||||
def _extend(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif not isinstance(parameter,list):
|
||||
return "ERROR:PARAMETER_NOT_LIST"
|
||||
elif len(variable)!=0 and len(parameter)!=0 and not isinstance(parameter[0],type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
return variable+parameter
|
||||
|
||||
def _remove_first(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[1:]
|
||||
|
||||
def _remove_last(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[:-1]
|
||||
|
||||
|
||||
def is_number(self, value):
|
||||
if isinstance(value, bool):
|
||||
return False
|
||||
return isinstance(value, numbers.Number)
|
||||
|
||||
def _add(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable+parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _subtract(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable-parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _multiply(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable*parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _divide(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
if parameter==0:
|
||||
return "ERROR:DIVIDE_BY_ZERO"
|
||||
else:
|
||||
return variable/parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
@ -18,12 +18,11 @@ import sys
|
||||
import logging
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask
|
||||
from quart import Blueprint, Quart, request, g, current_app, session
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flask_cors import CORS
|
||||
from flasgger import Swagger
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
from quart_cors import cors
|
||||
from common.constants import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services import UserService
|
||||
@ -31,17 +30,20 @@ from api.utils.json_encode import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from quart_auth import Unauthorized
|
||||
from common import settings
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.constants import API_VERSION
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Flask(__name__)
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
smtp_mail_server = Mail()
|
||||
|
||||
# Add this at the beginning of your file to configure Swagger UI
|
||||
@ -76,7 +78,6 @@ swagger = Swagger(
|
||||
},
|
||||
)
|
||||
|
||||
CORS(app, supports_credentials=True, max_age=2592000)
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
@ -84,17 +85,143 @@ app.errorhandler(Exception)(server_error_response)
|
||||
## convince for dev and debug
|
||||
# app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
app.config["SESSION_TYPE"] = "filesystem"
|
||||
app.config["SESSION_TYPE"] = "redis"
|
||||
app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis")
|
||||
app.config["MAX_CONTENT_LENGTH"] = int(
|
||||
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
|
||||
)
|
||||
|
||||
Session(app)
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
app.config['SECRET_KEY'] = settings.SECRET_KEY
|
||||
app.secret_key = settings.SECRET_KEY
|
||||
commands.register_commands(app)
|
||||
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
from collections.abc import Awaitable, Callable
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
def _load_user():
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = request.headers.get("Authorization")
|
||||
g.user = None
|
||||
if not authorization:
|
||||
return
|
||||
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
|
||||
|
||||
current_user = LocalProxy(_load_user)
|
||||
|
||||
|
||||
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
"""A decorator to restrict route access to authenticated users.
|
||||
|
||||
This should be used to wrap a route handler (or view function) to
|
||||
enforce that only authenticated requests can access it. Note that
|
||||
it is important that this decorator be wrapped by the route
|
||||
decorator and not vice, versa, as below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@app.route('/')
|
||||
@login_required
|
||||
async def index():
|
||||
...
|
||||
|
||||
If the request is not authenticated a
|
||||
`quart.exceptions.Unauthorized` exception will be raised.
|
||||
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if not current_user:# or not session.get("_user_id"):
|
||||
raise Unauthorized()
|
||||
else:
|
||||
return await current_app.ensure_async(func)(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def login_user(user, remember=False, duration=None, force=False, fresh=True):
|
||||
"""
|
||||
Logs a user in. You should pass the actual user object to this. If the
|
||||
user's `is_active` property is ``False``, they will not be logged in
|
||||
unless `force` is ``True``.
|
||||
|
||||
This will return ``True`` if the log in attempt succeeds, and ``False`` if
|
||||
it fails (i.e. because the user is inactive).
|
||||
|
||||
:param user: The user object to log in.
|
||||
:type user: object
|
||||
:param remember: Whether to remember the user after their session expires.
|
||||
Defaults to ``False``.
|
||||
:type remember: bool
|
||||
:param duration: The amount of time before the remember cookie expires. If
|
||||
``None`` the value set in the settings is used. Defaults to ``None``.
|
||||
:type duration: :class:`datetime.timedelta`
|
||||
:param force: If the user is inactive, setting this to ``True`` will log
|
||||
them in regardless. Defaults to ``False``.
|
||||
:type force: bool
|
||||
:param fresh: setting this to ``False`` will log in the user with a session
|
||||
marked as not "fresh". Defaults to ``True``.
|
||||
:type fresh: bool
|
||||
"""
|
||||
if not force and not user.is_active:
|
||||
return False
|
||||
|
||||
session["_user_id"] = user.id
|
||||
session["_fresh"] = fresh
|
||||
session["_id"] = get_uuid()
|
||||
return True
|
||||
|
||||
|
||||
def logout_user():
|
||||
"""
|
||||
Logs a user out. (You do not need to pass the actual user.) This will
|
||||
also clean up the remember me cookie if it exists.
|
||||
"""
|
||||
if "_user_id" in session:
|
||||
session.pop("_user_id")
|
||||
|
||||
if "_fresh" in session:
|
||||
session.pop("_fresh")
|
||||
|
||||
if "_id" in session:
|
||||
session.pop("_id")
|
||||
|
||||
COOKIE_NAME = "remember_token"
|
||||
cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
|
||||
if cookie_name in request.cookies:
|
||||
session["_remember"] = "clear"
|
||||
if "_remember_seconds" in session:
|
||||
session.pop("_remember_seconds")
|
||||
|
||||
return True
|
||||
|
||||
def search_pages_path(page_path):
|
||||
app_path_list = [
|
||||
@ -142,40 +269,6 @@ client_urls_prefix = [
|
||||
]
|
||||
|
||||
|
||||
@login_manager.request_loader
|
||||
def load_user(web_request):
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = web_request.headers.get("Authorization")
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return user[0]
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
|
||||
@ -14,20 +14,20 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime, timedelta
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
generate_confirmation_token
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def new_token():
|
||||
req = request.json
|
||||
async def new_token():
|
||||
req = await request.json
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
@ -72,8 +72,8 @@ def token_list():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("tokens", "tenant_id")
|
||||
@login_required
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request.json
|
||||
try:
|
||||
for token in req["tokens"]:
|
||||
APITokenService.filter_delete(
|
||||
|
||||
@ -18,12 +18,8 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from quart import request, Response, make_response
|
||||
from agent.component import LLM
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
@ -35,7 +31,8 @@ from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
|
||||
request_json
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
@ -46,6 +43,7 @@ from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@ -57,8 +55,9 @@ def templates():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
for i in req["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -70,8 +69,8 @@ def rm():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
async def save():
|
||||
req = await request_json()
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
@ -129,8 +128,8 @@ def getsse(canvas_id):
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
async def run():
|
||||
req = await request_json()
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
@ -160,10 +159,10 @@ def run():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
def sse():
|
||||
async def sse():
|
||||
nonlocal canvas, user_id
|
||||
try:
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
@ -179,15 +178,15 @@ def run():
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
resp.call_on_close(lambda: canvas.cancel_task())
|
||||
#resp.call_on_close(lambda: canvas.cancel_task())
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
async def rerun():
|
||||
req = await request_json()
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
@ -224,8 +223,8 @@ def cancel(task_id):
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
async def reset():
|
||||
req = await request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -245,7 +244,7 @@ def reset():
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
async def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
@ -311,7 +310,8 @@ def upload(canvas_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files['file']
|
||||
files = await request.files
|
||||
file = files['file']
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
@ -342,8 +342,8 @@ def input_form():
|
||||
@manager.route('/debug', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
async def debug():
|
||||
req = await request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -374,8 +374,8 @@ def debug():
|
||||
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
def test_db_connect():
|
||||
req = request.json
|
||||
async def test_db_connect():
|
||||
req = await request_json()
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
@ -519,8 +519,8 @@ def list_canvas():
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
async def setting():
|
||||
req = await request_json()
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@ -601,8 +601,8 @@ def prompts():
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
async def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
return await make_response(blob)
|
||||
|
||||
@ -18,8 +18,7 @@ import json
|
||||
import re
|
||||
|
||||
import xxhash
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -27,7 +26,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -35,13 +35,14 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def list_chunk():
|
||||
req = request.json
|
||||
async def list_chunk():
|
||||
req = await request_json()
|
||||
doc_id = req["doc_id"]
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
@ -121,8 +122,8 @@ def get():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
def set():
|
||||
req = request.json
|
||||
async def set():
|
||||
req = await request_json()
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -178,8 +179,8 @@ def set():
|
||||
@manager.route('/switch', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||
def switch():
|
||||
req = request.json
|
||||
async def switch():
|
||||
req = await request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
@ -198,8 +199,8 @@ def switch():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
@ -222,8 +223,8 @@ def rm():
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
def create():
|
||||
req = request.json
|
||||
async def create():
|
||||
req = await request_json()
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -280,8 +281,8 @@ def create():
|
||||
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test():
|
||||
req = request.json
|
||||
async def retrieval_test():
|
||||
req = await request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
|
||||
@ -20,8 +20,9 @@ import uuid
|
||||
from html import escape
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_login import current_user, login_required
|
||||
import flask
|
||||
import trio
|
||||
from quart import request
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
|
||||
from api.db import InputType
|
||||
@ -32,12 +33,13 @@ from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, Docum
|
||||
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_connector():
|
||||
req = request.json
|
||||
async def set_connector():
|
||||
req = await request.json
|
||||
if req.get("id"):
|
||||
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
||||
ConnectorService.update_by_id(req["id"], conn)
|
||||
@ -57,7 +59,7 @@ def set_connector():
|
||||
}
|
||||
ConnectorService.save(**conn)
|
||||
|
||||
time.sleep(1)
|
||||
await trio.sleep(1)
|
||||
e, conn = ConnectorService.get_by_id(req["id"])
|
||||
|
||||
return get_json_result(data=conn.to_dict())
|
||||
@ -88,8 +90,8 @@ def list_logs(connector_id):
|
||||
|
||||
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def resume(connector_id):
|
||||
req = request.json
|
||||
async def resume(connector_id):
|
||||
req = await request.json
|
||||
if req.get("resume"):
|
||||
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
|
||||
else:
|
||||
@ -100,8 +102,8 @@ def resume(connector_id):
|
||||
@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rebuild(connector_id):
|
||||
req = request.json
|
||||
async def rebuild(connector_id):
|
||||
req = await request.json
|
||||
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
|
||||
@ -163,7 +165,7 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
payload_json=payload_json,
|
||||
auto_close=auto_close,
|
||||
)
|
||||
response = make_response(html, 200)
|
||||
response = flask.make_response(html, 200)
|
||||
response.headers["Content-Type"] = "text/html; charset=utf-8"
|
||||
return response
|
||||
|
||||
@ -171,14 +173,14 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("credentials")
|
||||
def start_google_drive_web_oauth():
|
||||
async def start_google_drive_web_oauth():
|
||||
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Google Drive OAuth redirect URI is not configured on the server.",
|
||||
)
|
||||
|
||||
req = request.json or {}
|
||||
req = await request.json or {}
|
||||
raw_credentials = req.get("credentials", "")
|
||||
try:
|
||||
credentials = _load_credentials(raw_credentials)
|
||||
@ -279,8 +281,8 @@ def google_drive_web_oauth_callback():
|
||||
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("flow_id")
|
||||
def poll_google_drive_web_result():
|
||||
req = request.json or {}
|
||||
async def poll_google_drive_web_result():
|
||||
req = await request.json or {}
|
||||
flow_id = req.get("flow_id")
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
|
||||
if not cache_raw:
|
||||
|
||||
@ -17,8 +17,8 @@ import json
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
|
||||
@ -34,8 +34,8 @@ from common.constants import RetCode, LLMType
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_conversation():
|
||||
req = request.json
|
||||
async def set_conversation():
|
||||
req = await request.json
|
||||
conv_id = req.get("conversation_id")
|
||||
is_new = req.get("is_new")
|
||||
name = req.get("name", "New conversation")
|
||||
@ -128,8 +128,9 @@ def getsse(dialog_id):
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm():
|
||||
conv_ids = request.json["conversation_ids"]
|
||||
async def rm():
|
||||
req = await request.json
|
||||
conv_ids = req["conversation_ids"]
|
||||
try:
|
||||
for cid in conv_ids:
|
||||
exist, conv = ConversationService.get_by_id(cid)
|
||||
@ -165,8 +166,8 @@ def list_conversation():
|
||||
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
async def completion():
|
||||
req = await request.json
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
@ -250,8 +251,8 @@ def completion():
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def tts():
|
||||
req = request.json
|
||||
async def tts():
|
||||
req = await request.json
|
||||
text = req["text"]
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
@ -283,8 +284,8 @@ def tts():
|
||||
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def delete_msg():
|
||||
req = request.json
|
||||
async def delete_msg():
|
||||
req = await request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -306,8 +307,8 @@ def delete_msg():
|
||||
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def thumbup():
|
||||
req = request.json
|
||||
async def thumbup():
|
||||
req = await request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -333,8 +334,8 @@ def thumbup():
|
||||
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about():
|
||||
req = request.json
|
||||
async def ask_about():
|
||||
req = await request.json
|
||||
uid = current_user.id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -365,8 +366,8 @@ def ask_about():
|
||||
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
req = request.json
|
||||
async def mindmap():
|
||||
req = await request.json
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||
@ -383,8 +384,8 @@ def mindmap():
|
||||
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
def related_questions():
|
||||
req = request.json
|
||||
async def related_questions():
|
||||
req = await request.json
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
|
||||
@ -14,8 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from common.constants import StatusEnum
|
||||
@ -26,13 +25,14 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("prompt_config")
|
||||
@login_required
|
||||
def set_dialog():
|
||||
req = request.json
|
||||
async def set_dialog():
|
||||
req = await request.json
|
||||
dialog_id = req.get("dialog_id", "")
|
||||
is_create = not dialog_id
|
||||
name = req.get("name", "New Dialog")
|
||||
@ -169,18 +169,19 @@ def list_dialogs():
|
||||
|
||||
@manager.route('/next', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_dialogs_next():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
async def list_dialogs_next():
|
||||
args = request.args
|
||||
keywords = args.get("keywords", "")
|
||||
page_number = int(args.get("page", 0))
|
||||
items_per_page = int(args.get("page_size", 0))
|
||||
parser_id = args.get("parser_id")
|
||||
orderby = args.get("orderby", "create_time")
|
||||
if args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -207,8 +208,8 @@ def list_dialogs_next():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("dialog_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request.json
|
||||
dialog_list=[]
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
try:
|
||||
|
||||
@ -18,11 +18,8 @@ import os.path
|
||||
import pathlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from quart import request, make_response
|
||||
from api.apps import current_user, login_required
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
@ -39,7 +36,7 @@ from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
validate_request, request_json,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -53,14 +50,16 @@ from common import settings
|
||||
@manager.route("/upload", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def upload():
|
||||
kb_id = request.form.get("kb_id")
|
||||
async def upload():
|
||||
form = await request.form
|
||||
kb_id = form.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
if "file" not in request.files:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -87,12 +86,13 @@ def upload():
|
||||
@manager.route("/web_crawl", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "url")
|
||||
def web_crawl():
|
||||
kb_id = request.form.get("kb_id")
|
||||
async def web_crawl():
|
||||
form = await request.form
|
||||
kb_id = form.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
name = request.form.get("name")
|
||||
url = request.form.get("url")
|
||||
name = form.get("name")
|
||||
url = form.get("url")
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
@ -152,8 +152,8 @@ def web_crawl():
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "kb_id")
|
||||
def create():
|
||||
req = request.json
|
||||
async def create():
|
||||
req = await request_json()
|
||||
kb_id = req["kb_id"]
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -208,7 +208,7 @@ def create():
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_docs():
|
||||
async def list_docs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -230,7 +230,7 @@ def list_docs():
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
|
||||
run_status = req.get("run_status", [])
|
||||
if run_status:
|
||||
@ -270,8 +270,8 @@ def list_docs():
|
||||
|
||||
@manager.route("/filter", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def get_filter():
|
||||
req = request.get_json()
|
||||
async def get_filter():
|
||||
req = await request.get_json()
|
||||
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
@ -308,8 +308,8 @@ def get_filter():
|
||||
|
||||
@manager.route("/infos", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def doc_infos():
|
||||
req = request.json
|
||||
async def doc_infos():
|
||||
req = await request_json()
|
||||
doc_ids = req["doc_ids"]
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
@ -340,8 +340,8 @@ def thumbnails():
|
||||
@manager.route("/change_status", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids", "status")
|
||||
def change_status():
|
||||
req = request.get_json()
|
||||
async def change_status():
|
||||
req = await request.get_json()
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
status = str(req.get("status", ""))
|
||||
|
||||
@ -380,8 +380,8 @@ def change_status():
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
doc_ids = req["doc_id"]
|
||||
if isinstance(doc_ids, str):
|
||||
doc_ids = [doc_ids]
|
||||
@ -401,8 +401,8 @@ def rm():
|
||||
@manager.route("/run", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids", "run")
|
||||
def run():
|
||||
req = request.json
|
||||
async def run():
|
||||
req = await request_json()
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
@ -448,8 +448,8 @@ def run():
|
||||
@manager.route("/rename", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
async def rename():
|
||||
req = await request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
@ -495,14 +495,14 @@ def rename():
|
||||
|
||||
@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def get(doc_id):
|
||||
async def get(doc_id):
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))
|
||||
response = await make_response(settings.STORAGE_IMPL.get(b, n))
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
@ -520,12 +520,12 @@ def get(doc_id):
|
||||
|
||||
@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def download_attachment(attachment_id):
|
||||
async def download_attachment(attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
|
||||
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
|
||||
response = flask.make_response(data)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
return response
|
||||
@ -537,9 +537,9 @@ def download_attachment(attachment_id):
|
||||
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def change_parser():
|
||||
async def change_parser():
|
||||
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -590,13 +590,13 @@ def change_parser():
|
||||
|
||||
@manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def get_image(image_id):
|
||||
async def get_image(image_id):
|
||||
try:
|
||||
arr = image_id.split("-")
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
|
||||
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
except Exception as e:
|
||||
@ -606,24 +606,25 @@ def get_image(image_id):
|
||||
@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id")
|
||||
def upload_and_parse():
|
||||
if "file" not in request.files:
|
||||
async def upload_and_parse():
|
||||
files = await request.file
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
||||
|
||||
form = await request.form
|
||||
doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route("/parse", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def parse():
|
||||
url = request.json.get("url") if request.json else ""
|
||||
async def parse():
|
||||
url = await request.json.get("url") if await request.json else ""
|
||||
if url:
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -664,10 +665,11 @@ def parse():
|
||||
txt = FileService.parse_docs([f], current_user.id)
|
||||
return get_json_result(data=txt)
|
||||
|
||||
if "file" not in request.files:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
txt = FileService.parse_docs(file_objs, current_user.id)
|
||||
|
||||
return get_json_result(data=txt)
|
||||
@ -676,8 +678,8 @@ def parse():
|
||||
@manager.route("/set_meta", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "meta")
|
||||
def set_meta():
|
||||
req = request.json
|
||||
async def set_meta():
|
||||
req = await request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
|
||||
@ -19,8 +19,8 @@ from pathlib import Path
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
@ -33,8 +33,8 @@ from api.utils.api_utils import get_json_result
|
||||
@manager.route('/convert', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids", "kb_ids")
|
||||
def convert():
|
||||
req = request.json
|
||||
async def convert():
|
||||
req = await request.json
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
@ -103,8 +103,8 @@ def convert():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request.json
|
||||
file_ids = req["file_ids"]
|
||||
if not file_ids:
|
||||
return get_json_result(
|
||||
|
||||
@ -17,10 +17,8 @@ import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request, make_response
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -40,17 +38,19 @@ from common import settings
|
||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
# @validate_request("parent_id")
|
||||
def upload():
|
||||
pf_id = request.form.get("parent_id")
|
||||
async def upload():
|
||||
form = await request.form
|
||||
pf_id = form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
files = await request.files
|
||||
if 'file' not in files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
file_objs = request.files.getlist('file')
|
||||
file_objs = files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
@ -123,10 +123,10 @@ def upload():
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
async def create():
|
||||
req = await request.json
|
||||
pf_id = await request.json.get("parent_id")
|
||||
input_file_type = await request.json.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@ -238,8 +238,8 @@ def get_all_parent_folders():
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request.json
|
||||
file_ids = req["file_ids"]
|
||||
|
||||
def _delete_single_file(file):
|
||||
@ -299,8 +299,8 @@ def rm():
|
||||
@manager.route('/rename', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
async def rename():
|
||||
req = await request.json
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
@ -338,7 +338,7 @@ def rename():
|
||||
|
||||
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(file_id):
|
||||
async def get(file_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@ -351,7 +351,7 @@ def get(file_id):
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = settings.STORAGE_IMPL.get(b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
@ -368,8 +368,8 @@ def get(file_id):
|
||||
@manager.route("/mv", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("src_file_ids", "dest_file_id")
|
||||
def move():
|
||||
req = request.json
|
||||
async def move():
|
||||
req = await request.json
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
dest_parent_id = req["dest_file_id"]
|
||||
|
||||
@ -18,11 +18,9 @@ import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
import numpy as np
|
||||
|
||||
|
||||
from api.db.services.connector_service import Connector2KbService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
@ -31,7 +29,8 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
|
||||
request_json
|
||||
from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
@ -42,27 +41,28 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
async def create():
|
||||
req = await request_json()
|
||||
e, res = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = current_user.id,
|
||||
parser_id = req.pop("parser_id", None),
|
||||
**req
|
||||
)
|
||||
|
||||
code = req.get("code")
|
||||
if code:
|
||||
return get_data_error_result(code=code, message=req.get("message"))
|
||||
if not e:
|
||||
return res
|
||||
|
||||
try:
|
||||
if not KnowledgebaseService.save(**req):
|
||||
if not KnowledgebaseService.save(**res):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id":req["id"]})
|
||||
return get_json_result(data={"kb_id":res["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -71,8 +71,8 @@ def create():
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "description", "parser_id")
|
||||
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.json
|
||||
async def update():
|
||||
req = await request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -170,18 +170,19 @@ def detail():
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_kbs():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
async def list_kbs():
|
||||
args = request.args
|
||||
keywords = args.get("keywords", "")
|
||||
page_number = int(args.get("page", 0))
|
||||
items_per_page = int(args.get("page_size", 0))
|
||||
parser_id = args.get("parser_id")
|
||||
orderby = args.get("orderby", "create_time")
|
||||
if args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -203,11 +204,12 @@ def list_kbs():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -283,8 +285,8 @@ def list_tags_from_kbs():
|
||||
|
||||
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rm_tags(kb_id):
|
||||
req = request.json
|
||||
async def rm_tags(kb_id):
|
||||
req = await request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -303,8 +305,8 @@ def rm_tags(kb_id):
|
||||
|
||||
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rename_tags(kb_id):
|
||||
req = request.json
|
||||
async def rename_tags(kb_id):
|
||||
req = await request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -407,7 +409,7 @@ def get_basic_info():
|
||||
|
||||
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_logs():
|
||||
async def list_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -426,7 +428,7 @@ def list_pipeline_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -451,7 +453,7 @@ def list_pipeline_logs():
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
async def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -468,7 +470,7 @@ def list_pipeline_dataset_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -485,12 +487,12 @@ def list_pipeline_dataset_logs():
|
||||
|
||||
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_pipeline_logs():
|
||||
async def delete_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
log_ids = req.get("log_ids", [])
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
@ -514,8 +516,8 @@ def pipeline_log_detail():
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
async def run_graphrag():
|
||||
req = await request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -583,8 +585,8 @@ def trace_graphrag():
|
||||
|
||||
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_raptor():
|
||||
req = request.json
|
||||
async def run_raptor():
|
||||
req = await request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -652,8 +654,8 @@ def trace_raptor():
|
||||
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_mindmap():
|
||||
req = request.json
|
||||
async def run_mindmap():
|
||||
req = await request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -768,7 +770,7 @@ def delete_kb_task():
|
||||
|
||||
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
def check_embedding():
|
||||
async def check_embedding():
|
||||
|
||||
def _guess_vec_field(src: dict) -> str | None:
|
||||
for k in src or {}:
|
||||
@ -819,7 +821,7 @@ def check_embedding():
|
||||
return []
|
||||
|
||||
n = min(n, total)
|
||||
offsets = sorted(random.sample(range(total), n))
|
||||
offsets = sorted(random.sample(range(min(total,1000)), n))
|
||||
out = []
|
||||
|
||||
for off in offsets:
|
||||
@ -859,7 +861,7 @@ def check_embedding():
|
||||
def _clean(s: str) -> str:
|
||||
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
|
||||
return s if s else "None"
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
kb_id = req.get("kb_id", "")
|
||||
embd_id = req.get("embd_id", "")
|
||||
n = int(req.get("check_num", 5))
|
||||
|
||||
@ -15,8 +15,8 @@
|
||||
#
|
||||
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
from api.apps import current_user, login_required
|
||||
from langfuse import Langfuse
|
||||
|
||||
from api.db.db_models import DB
|
||||
@ -27,8 +27,8 @@ from api.utils.api_utils import get_error_data_result, get_json_result, server_e
|
||||
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("secret_key", "public_key", "host")
|
||||
def set_api_key():
|
||||
req = request.get_json()
|
||||
async def set_api_key():
|
||||
req = await request.get_json()
|
||||
secret_key = req.get("secret_key", "")
|
||||
public_key = req.get("public_key", "")
|
||||
host = req.get("host", "")
|
||||
|
||||
@ -16,8 +16,9 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
@ -52,8 +53,8 @@ def factories():
|
||||
@manager.route("/set_api_key", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
async def set_api_key():
|
||||
req = await request.json
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
@ -122,8 +123,8 @@ def set_api_key():
|
||||
@manager.route("/add_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def add_llm():
|
||||
req = request.json
|
||||
async def add_llm():
|
||||
req = await request.json
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
@ -142,11 +143,11 @@ def add_llm():
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
@ -267,8 +268,8 @@ def add_llm():
|
||||
@manager.route("/delete_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def delete_llm():
|
||||
req = request.json
|
||||
async def delete_llm():
|
||||
req = await request.json
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -276,8 +277,8 @@ def delete_llm():
|
||||
@manager.route("/enable_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def enable_llm():
|
||||
req = request.json
|
||||
async def enable_llm():
|
||||
req = await request.json
|
||||
TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
|
||||
)
|
||||
@ -287,8 +288,8 @@ def enable_llm():
|
||||
@manager.route("/delete_factory", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def delete_factory():
|
||||
req = request.json
|
||||
async def delete_factory():
|
||||
req = await request.json
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
from api.db.db_models import MCPServer
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
@ -30,7 +30,7 @@ from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_too
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_mcp() -> Response:
|
||||
async def list_mcp() -> Response:
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
@ -40,7 +40,7 @@ def list_mcp() -> Response:
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
try:
|
||||
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
|
||||
@ -72,8 +72,8 @@ def detail() -> Response:
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "url", "server_type")
|
||||
def create() -> Response:
|
||||
req = request.get_json()
|
||||
async def create() -> Response:
|
||||
req = await request.get_json()
|
||||
|
||||
server_type = req.get("server_type", "")
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
@ -127,8 +127,8 @@ def create() -> Response:
|
||||
@manager.route("/update", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id")
|
||||
def update() -> Response:
|
||||
req = request.get_json()
|
||||
async def update() -> Response:
|
||||
req = await request.get_json()
|
||||
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
@ -183,8 +183,8 @@ def update() -> Response:
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def rm() -> Response:
|
||||
req = request.get_json()
|
||||
async def rm() -> Response:
|
||||
req = await request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
try:
|
||||
@ -201,8 +201,8 @@ def rm() -> Response:
|
||||
@manager.route("/import", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcpServers")
|
||||
def import_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
async def import_multiple() -> Response:
|
||||
req = await request.get_json()
|
||||
servers = req.get("mcpServers", {})
|
||||
if not servers:
|
||||
return get_data_error_result(message="No MCP servers provided.")
|
||||
@ -268,8 +268,8 @@ def import_multiple() -> Response:
|
||||
@manager.route("/export", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def export_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
async def export_multiple() -> Response:
|
||||
req = await request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
if not mcp_ids:
|
||||
@ -300,8 +300,8 @@ def export_multiple() -> Response:
|
||||
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def list_tools() -> Response:
|
||||
req = request.get_json()
|
||||
async def list_tools() -> Response:
|
||||
req = await request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
@ -347,8 +347,8 @@ def list_tools() -> Response:
|
||||
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tool_name", "arguments")
|
||||
def test_tool() -> Response:
|
||||
req = request.get_json()
|
||||
async def test_tool() -> Response:
|
||||
req = await request.get_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -380,8 +380,8 @@ def test_tool() -> Response:
|
||||
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tools")
|
||||
def cache_tool() -> Response:
|
||||
req = request.get_json()
|
||||
async def cache_tool() -> Response:
|
||||
req = await request.get_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -403,8 +403,8 @@ def cache_tool() -> Response:
|
||||
|
||||
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
|
||||
@validate_request("url", "server_type")
|
||||
def test_mcp() -> Response:
|
||||
req = request.get_json()
|
||||
async def test_mcp() -> Response:
|
||||
req = await request.get_json()
|
||||
|
||||
url = req.get("url", "")
|
||||
if not url:
|
||||
|
||||
@ -15,8 +15,8 @@
|
||||
#
|
||||
|
||||
|
||||
from flask import Response
|
||||
from flask_login import login_required
|
||||
from quart import Response
|
||||
from api.apps import login_required
|
||||
from api.utils.api_utils import get_json_result
|
||||
from plugin import GlobalPluginManager
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from flask import request, Response
|
||||
from quart import request, Response
|
||||
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@ -52,8 +52,8 @@ def list_agents(tenant_id):
|
||||
|
||||
@manager.route("/agents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], request.json)
|
||||
async def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], await request.json)
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -89,8 +89,8 @@ def create_agent(tenant_id: str):
|
||||
|
||||
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None}
|
||||
async def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -135,8 +135,8 @@ def delete_agent(tenant_id: str, agent_id: str):
|
||||
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def webhook(tenant_id: str, agent_id: str):
|
||||
req = request.json
|
||||
async def webhook(tenant_id: str, agent_id: str):
|
||||
req = await request.json
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
|
||||
@ -14,22 +14,20 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
from quart import request
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json
|
||||
|
||||
|
||||
@manager.route("/chats", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
req = request.json
|
||||
async def create(tenant_id):
|
||||
req = await request_json()
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
@ -145,10 +143,10 @@ def create(tenant_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id):
|
||||
async def update(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
ids = req.get("dataset_ids", [])
|
||||
if "show_quotation" in req:
|
||||
req["do_refer"] = req.pop("show_quotation")
|
||||
@ -228,10 +226,10 @@ def update(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/chats", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
async def delete_chats(tenant_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if not req:
|
||||
ids = None
|
||||
else:
|
||||
@ -251,8 +249,8 @@ def delete(tenant_id):
|
||||
errors.append(f"Assistant({id}) not found.")
|
||||
continue
|
||||
temp_dict = {"status": StatusEnum.INVALID.value}
|
||||
DialogService.update_by_id(id, temp_dict)
|
||||
success_count += 1
|
||||
success_count += DialogService.update_by_id(id, temp_dict)
|
||||
print(success_count, "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$", flush=True)
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from flask import request
|
||||
from quart import request
|
||||
from peewee import OperationalError
|
||||
from api.db.db_models import File
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
@ -54,7 +54,7 @@ from common import settings
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
"""
|
||||
Create a new dataset.
|
||||
---
|
||||
@ -116,16 +116,19 @@ def create(tenant_id):
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
|
||||
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
req, err = await validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
e, req = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = tenant_id,
|
||||
parser_id = req.pop("parser_id", None),
|
||||
**req
|
||||
)
|
||||
|
||||
if not e:
|
||||
return req
|
||||
|
||||
# Insert embedding model(embd id)
|
||||
ok, t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
@ -152,7 +155,7 @@ def create(tenant_id):
|
||||
|
||||
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
async def delete(tenant_id):
|
||||
"""
|
||||
Delete datasets.
|
||||
---
|
||||
@ -190,7 +193,7 @@ def delete(tenant_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
|
||||
req, err = await validate_and_parse_json_request(request, DeleteDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
@ -250,7 +253,7 @@ def delete(tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, dataset_id):
|
||||
async def update(tenant_id, dataset_id):
|
||||
"""
|
||||
Update a dataset.
|
||||
---
|
||||
@ -316,7 +319,7 @@ def update(tenant_id, dataset_id):
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
extras = {"dataset_id": dataset_id}
|
||||
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request, jsonify
|
||||
from quart import request, jsonify
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -29,7 +29,7 @@ from common import settings
|
||||
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
|
||||
@apikey_required
|
||||
@validate_request("knowledge_id", "query")
|
||||
def retrieval(tenant_id):
|
||||
async def retrieval(tenant_id):
|
||||
"""
|
||||
Dify-compatible retrieval API
|
||||
---
|
||||
@ -113,7 +113,7 @@ def retrieval(tenant_id):
|
||||
404:
|
||||
description: Knowledge base or document not found
|
||||
"""
|
||||
req = request.json
|
||||
req = await request.json
|
||||
question = req["query"]
|
||||
kb_id = req["knowledge_id"]
|
||||
use_kg = req.get("use_kg", False)
|
||||
@ -131,12 +131,10 @@ def retrieval(tenant_id):
|
||||
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
print(metadata_condition)
|
||||
# print("after", convert_conditions(metadata_condition))
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
|
||||
# print("doc_ids", doc_ids)
|
||||
if not doc_ids and metadata_condition is not None:
|
||||
doc_ids = ['-999']
|
||||
if metadata_condition:
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
|
||||
if not doc_ids and metadata_condition:
|
||||
doc_ids = ["-999"]
|
||||
ranks = settings.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
|
||||
@ -20,7 +20,7 @@ import re
|
||||
from io import BytesIO
|
||||
|
||||
import xxhash
|
||||
from flask import request, send_file
|
||||
from quart import request, send_file
|
||||
from peewee import OperationalError
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
@ -35,7 +35,8 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||
request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -69,7 +70,7 @@ class Chunk(BaseModel):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def upload(dataset_id, tenant_id):
|
||||
async def upload(dataset_id, tenant_id):
|
||||
"""
|
||||
Upload documents to a dataset.
|
||||
---
|
||||
@ -130,9 +131,11 @@ def upload(dataset_id, tenant_id):
|
||||
type: string
|
||||
description: Processing status.
|
||||
"""
|
||||
if "file" not in request.files:
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_result(message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -155,7 +158,7 @@ def upload(dataset_id, tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not e:
|
||||
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=request.form.get("parent_path"))
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=form.get("parent_path"))
|
||||
if err:
|
||||
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
# rename key's name
|
||||
@ -179,7 +182,7 @@ def upload(dataset_id, tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update_doc(tenant_id, dataset_id, document_id):
|
||||
async def update_doc(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Update a document within a dataset.
|
||||
---
|
||||
@ -228,7 +231,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
|
||||
return get_error_data_result(message="You don't own the dataset.")
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
@ -359,7 +362,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def download(tenant_id, dataset_id, document_id):
|
||||
async def download(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Download a document from a dataset.
|
||||
---
|
||||
@ -409,10 +412,10 @@ def download(tenant_id, dataset_id, document_id):
|
||||
return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
|
||||
file = BytesIO(file_stream)
|
||||
# Use send_file with a proper filename and MIME type
|
||||
return send_file(
|
||||
return await send_file(
|
||||
file,
|
||||
as_attachment=True,
|
||||
download_name=doc[0].name,
|
||||
attachment_filename=doc[0].name,
|
||||
mimetype="application/octet-stream", # Set a default MIME type
|
||||
)
|
||||
|
||||
@ -589,7 +592,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, dataset_id):
|
||||
async def delete(tenant_id, dataset_id):
|
||||
"""
|
||||
Delete documents from a dataset.
|
||||
---
|
||||
@ -628,7 +631,7 @@ def delete(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if not req:
|
||||
doc_ids = None
|
||||
else:
|
||||
@ -699,7 +702,7 @@ def delete(tenant_id, dataset_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def parse(tenant_id, dataset_id):
|
||||
async def parse(tenant_id, dataset_id):
|
||||
"""
|
||||
Start parsing documents into chunks.
|
||||
---
|
||||
@ -738,7 +741,7 @@ def parse(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
doc_list = req.get("document_ids")
|
||||
@ -782,7 +785,7 @@ def parse(tenant_id, dataset_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def stop_parsing(tenant_id, dataset_id):
|
||||
async def stop_parsing(tenant_id, dataset_id):
|
||||
"""
|
||||
Stop parsing documents into chunks.
|
||||
---
|
||||
@ -821,7 +824,7 @@ def stop_parsing(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
@ -1023,7 +1026,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
"/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["POST"]
|
||||
)
|
||||
@token_required
|
||||
def add_chunk(tenant_id, dataset_id, document_id):
|
||||
async def add_chunk(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Add a chunk to a document.
|
||||
---
|
||||
@ -1093,7 +1096,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if not str(req.get("content", "")).strip():
|
||||
return get_error_data_result(message="`content` is required")
|
||||
if "important_keywords" in req:
|
||||
@ -1152,7 +1155,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
"datasets/<dataset_id>/documents/<document_id>/chunks", methods=["DELETE"]
|
||||
)
|
||||
@token_required
|
||||
def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
async def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Remove chunks from a document.
|
||||
---
|
||||
@ -1199,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
docs = DocumentService.get_by_ids([document_id])
|
||||
if not docs:
|
||||
raise LookupError(f"Can't find the document with ID {document_id}!")
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
condition = {"doc_id": document_id}
|
||||
if "chunk_ids" in req:
|
||||
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
|
||||
@ -1223,7 +1226,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
"/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id>", methods=["PUT"]
|
||||
)
|
||||
@token_required
|
||||
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
"""
|
||||
Update a chunk within a document.
|
||||
---
|
||||
@ -1285,7 +1288,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if "content" in req:
|
||||
content = req["content"]
|
||||
else:
|
||||
@ -1327,7 +1330,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
|
||||
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def retrieval_test(tenant_id):
|
||||
async def retrieval_test(tenant_id):
|
||||
"""
|
||||
Retrieve chunks based on a query.
|
||||
---
|
||||
@ -1408,7 +1411,7 @@ def retrieval_test(tenant_id):
|
||||
format: float
|
||||
description: Similarity score.
|
||||
"""
|
||||
req = request.json
|
||||
req = await request_json()
|
||||
if not req.get("dataset_ids"):
|
||||
return get_error_data_result("`dataset_ids` is required.")
|
||||
kb_ids = req["dataset_ids"]
|
||||
|
||||
@ -17,9 +17,7 @@
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from quart import request, make_response
|
||||
from pathlib import Path
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -37,7 +35,7 @@ from common import settings
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def upload(tenant_id):
|
||||
async def upload(tenant_id):
|
||||
"""
|
||||
Upload a file to the system.
|
||||
---
|
||||
@ -79,15 +77,17 @@ def upload(tenant_id):
|
||||
type: string
|
||||
description: File type (e.g., document, folder)
|
||||
"""
|
||||
pf_id = request.form.get("parent_id")
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
pf_id = form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
if 'file' not in files:
|
||||
return get_json_result(data=False, message='No file part!', code=400)
|
||||
file_objs = request.files.getlist('file')
|
||||
file_objs = files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
@ -151,7 +151,7 @@ def upload(tenant_id):
|
||||
|
||||
@manager.route('/file/create', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
"""
|
||||
Create a new file or folder.
|
||||
---
|
||||
@ -193,9 +193,9 @@ def create(tenant_id):
|
||||
type:
|
||||
type: string
|
||||
"""
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
req = await request.json
|
||||
pf_id = await request.json.get("parent_id")
|
||||
input_file_type = await request.json.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
@ -450,7 +450,7 @@ def get_all_parent_folders(tenant_id):
|
||||
|
||||
@manager.route('/file/rm', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rm(tenant_id):
|
||||
async def rm(tenant_id):
|
||||
"""
|
||||
Delete one or multiple files/folders.
|
||||
---
|
||||
@ -481,7 +481,7 @@ def rm(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await request.json
|
||||
file_ids = req["file_ids"]
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
@ -524,7 +524,7 @@ def rm(tenant_id):
|
||||
|
||||
@manager.route('/file/rename', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rename(tenant_id):
|
||||
async def rename(tenant_id):
|
||||
"""
|
||||
Rename a file.
|
||||
---
|
||||
@ -556,7 +556,7 @@ def rename(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await request.json
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
@ -585,7 +585,7 @@ def rename(tenant_id):
|
||||
|
||||
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get(tenant_id, file_id):
|
||||
async def get(tenant_id, file_id):
|
||||
"""
|
||||
Download a file.
|
||||
---
|
||||
@ -619,7 +619,7 @@ def get(tenant_id, file_id):
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = settings.STORAGE_IMPL.get(b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name)
|
||||
if ext:
|
||||
if file.type == FileType.VISUAL.value:
|
||||
@ -633,7 +633,7 @@ def get(tenant_id, file_id):
|
||||
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def move(tenant_id):
|
||||
async def move(tenant_id):
|
||||
"""
|
||||
Move one or multiple files to another folder.
|
||||
---
|
||||
@ -667,7 +667,7 @@ def move(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await request.json
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
parent_id = req["dest_file_id"]
|
||||
@ -693,8 +693,8 @@ def move(tenant_id):
|
||||
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def convert(tenant_id):
|
||||
req = request.json
|
||||
async def convert(tenant_id):
|
||||
req = await request.json
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
|
||||
@ -18,7 +18,7 @@ import re
|
||||
import time
|
||||
|
||||
import tiktoken
|
||||
from flask import Response, jsonify, request
|
||||
from quart import Response, jsonify, request
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db.db_models import APIToken
|
||||
@ -44,8 +44,8 @@ from common import settings
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id, chat_id):
|
||||
req = request.json
|
||||
async def create(tenant_id, chat_id):
|
||||
req = await request.json
|
||||
req["dialog_id"] = chat_id
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
@ -97,8 +97,8 @@ def create_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id, session_id):
|
||||
req = request.json
|
||||
async def update(tenant_id, chat_id, session_id):
|
||||
req = await request.json
|
||||
req["dialog_id"] = chat_id
|
||||
conv_id = session_id
|
||||
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
||||
@ -119,8 +119,8 @@ def update(tenant_id, chat_id, session_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def chat_completion(tenant_id, chat_id):
|
||||
req = request.json
|
||||
async def chat_completion(tenant_id, chat_id):
|
||||
req = await request.json
|
||||
if not req:
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
@ -149,7 +149,7 @@ def chat_completion(tenant_id, chat_id):
|
||||
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def chat_completion_openai_like(tenant_id, chat_id):
|
||||
async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"""
|
||||
OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
|
||||
|
||||
@ -206,7 +206,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
if reference:
|
||||
print(completion.choices[0].message.reference)
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
|
||||
need_reference = bool(req.get("reference", False))
|
||||
|
||||
@ -383,8 +383,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = request.json
|
||||
async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = await request.json
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
messages = req.get("messages", [])
|
||||
if not messages:
|
||||
@ -443,8 +443,8 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def agent_completions(tenant_id, agent_id):
|
||||
req = request.json
|
||||
async def agent_completions(tenant_id, agent_id):
|
||||
req = await request.json
|
||||
|
||||
if req.get("stream", True):
|
||||
|
||||
@ -610,13 +610,13 @@ def list_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, chat_id):
|
||||
async def delete(tenant_id, chat_id):
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You don't own the chat")
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await request.json
|
||||
convs = ConversationService.query(dialog_id=chat_id)
|
||||
if not req:
|
||||
ids = None
|
||||
@ -661,10 +661,10 @@ def delete(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent_session(tenant_id, agent_id):
|
||||
async def delete_agent_session(tenant_id, agent_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await request.json
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
@ -716,8 +716,8 @@ def delete_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def ask_about(tenant_id):
|
||||
req = request.json
|
||||
async def ask_about(tenant_id):
|
||||
req = await request.json
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
if not req.get("dataset_ids"):
|
||||
@ -755,8 +755,8 @@ def ask_about(tenant_id):
|
||||
|
||||
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def related_questions(tenant_id):
|
||||
req = request.json
|
||||
async def related_questions(tenant_id):
|
||||
req = await request.json
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
question = req["question"]
|
||||
@ -806,8 +806,8 @@ Related search terms:
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def chatbot_completions(dialog_id):
|
||||
req = request.json
|
||||
async def chatbot_completions(dialog_id):
|
||||
req = await request.json
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -856,8 +856,8 @@ def chatbots_inputs(dialog_id):
|
||||
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def agent_bot_completions(agent_id):
|
||||
req = request.json
|
||||
async def agent_bot_completions(agent_id):
|
||||
req = await request.json
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -901,7 +901,7 @@ def begin_inputs(agent_id):
|
||||
|
||||
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about_embedded():
|
||||
async def ask_about_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -910,7 +910,7 @@ def ask_about_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await request.json
|
||||
uid = objs[0].tenant_id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -940,7 +940,7 @@ def ask_about_embedded():
|
||||
|
||||
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test_embedded():
|
||||
async def retrieval_test_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -949,7 +949,7 @@ def retrieval_test_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await request.json
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -1039,7 +1039,7 @@ def retrieval_test_embedded():
|
||||
|
||||
@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question")
|
||||
def related_questions_embedded():
|
||||
async def related_questions_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1048,7 +1048,7 @@ def related_questions_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await request.json
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
@ -1115,7 +1115,7 @@ def detail_share_embedded():
|
||||
|
||||
@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
async def mindmap():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1125,7 +1125,7 @@ def mindmap():
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
req = await request.json
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.db.db_models import DB
|
||||
@ -30,8 +30,8 @@ from api.utils.api_utils import get_data_error_result, get_json_result, not_allo
|
||||
@manager.route("/create", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.get_json()
|
||||
async def create():
|
||||
req = await request.get_json()
|
||||
search_name = req["name"]
|
||||
description = req.get("description", "")
|
||||
if not isinstance(search_name, str):
|
||||
@ -65,8 +65,8 @@ def create():
|
||||
@login_required
|
||||
@validate_request("search_id", "name", "search_config", "tenant_id")
|
||||
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.get_json()
|
||||
async def update():
|
||||
req = await request.get_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Search name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -140,7 +140,7 @@ def detail():
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_search_app():
|
||||
async def list_search_app():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
@ -150,7 +150,7 @@ def list_search_app():
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -173,8 +173,8 @@ def list_search_app():
|
||||
@manager.route("/rm", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("search_id")
|
||||
def rm():
|
||||
req = request.get_json()
|
||||
async def rm():
|
||||
req = await request.get_json()
|
||||
search_id = req["search_id"]
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -17,7 +17,7 @@ import logging
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService
|
||||
@ -34,7 +34,7 @@ from common.time_utils import current_timestamp, datetime_format
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from quart import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -14,10 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.apps import smtp_mail_server
|
||||
from quart import request
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import UserTenant
|
||||
from api.db.services.user_service import UserTenantService, UserService
|
||||
@ -28,6 +25,7 @@ from common.time_utils import delta_seconds
|
||||
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
|
||||
from api.utils.web_utils import send_invite_email
|
||||
from common import settings
|
||||
from api.apps import smtp_mail_server, login_required, current_user
|
||||
|
||||
|
||||
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
|
||||
@ -51,14 +49,14 @@ def user_list(tenant_id):
|
||||
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("email")
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
if current_user.id != tenant_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
req = await request.json
|
||||
invite_user_email = req["email"]
|
||||
invite_users = UserService.query(email=invite_user_email)
|
||||
if not invite_users:
|
||||
|
||||
@ -22,8 +22,7 @@ import secrets
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from flask import redirect, request, session, make_response
|
||||
from flask_login import current_user, login_required, login_user, logout_user
|
||||
from quart import redirect, request, session, make_response
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
from api.apps.auth import get_auth_client
|
||||
@ -45,7 +44,7 @@ from api.utils.api_utils import (
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import smtp_mail_server
|
||||
from api.apps import smtp_mail_server, login_required, current_user, login_user, logout_user
|
||||
from api.utils.web_utils import (
|
||||
send_email_html,
|
||||
OTP_LENGTH,
|
||||
@ -61,7 +60,7 @@ from common import settings
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
def login():
|
||||
async def login():
|
||||
"""
|
||||
User login endpoint.
|
||||
---
|
||||
@ -91,10 +90,11 @@ def login():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
if not request.json:
|
||||
json_body = await request.json
|
||||
if not json_body:
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
email = request.json.get("email", "")
|
||||
email = json_body.get("email", "")
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(
|
||||
@ -103,7 +103,7 @@ def login():
|
||||
message=f"Email: {email} is not registered!",
|
||||
)
|
||||
|
||||
password = request.json.get("password")
|
||||
password = json_body.get("password")
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
@ -125,7 +125,8 @@ def login():
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
|
||||
return await construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -501,7 +502,7 @@ def log_out():
|
||||
|
||||
@manager.route("/setting", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def setting_user():
|
||||
async def setting_user():
|
||||
"""
|
||||
Update user settings.
|
||||
---
|
||||
@ -530,7 +531,7 @@ def setting_user():
|
||||
type: object
|
||||
"""
|
||||
update_dict = {}
|
||||
request_data = request.json
|
||||
request_data = await request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
@ -660,7 +661,7 @@ def user_register(user_id, user):
|
||||
|
||||
@manager.route("/register", methods=["POST"]) # noqa: F821
|
||||
@validate_request("nickname", "email", "password")
|
||||
def user_add():
|
||||
async def user_add():
|
||||
"""
|
||||
Register a new user.
|
||||
---
|
||||
@ -697,7 +698,7 @@ def user_add():
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
req = request.json
|
||||
req = await request.json
|
||||
email_address = req["email"]
|
||||
|
||||
# Validate the email address
|
||||
@ -737,7 +738,7 @@ def user_add():
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return construct_response(
|
||||
return await construct_response(
|
||||
data=user.to_json(),
|
||||
auth=user.get_id(),
|
||||
message=f"{nickname}, welcome aboard!",
|
||||
@ -793,7 +794,7 @@ def tenant_info():
|
||||
@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
|
||||
def set_tenant_info():
|
||||
async def set_tenant_info():
|
||||
"""
|
||||
Update tenant information.
|
||||
---
|
||||
@ -830,7 +831,7 @@ def set_tenant_info():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
req = await request.json
|
||||
try:
|
||||
tid = req.pop("tenant_id")
|
||||
TenantService.update_by_id(tid, req)
|
||||
@ -840,7 +841,7 @@ def set_tenant_info():
|
||||
|
||||
|
||||
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
|
||||
def forget_get_captcha():
|
||||
async def forget_get_captcha():
|
||||
"""
|
||||
GET /forget/captcha?email=<email>
|
||||
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS.
|
||||
@ -862,19 +863,19 @@ def forget_get_captcha():
|
||||
from captcha.image import ImageCaptcha
|
||||
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
|
||||
img_bytes = image.generate(captcha_text).read()
|
||||
response = make_response(img_bytes)
|
||||
response = await make_response(img_bytes)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/forget/otp", methods=["POST"]) # noqa: F821
|
||||
def forget_send_otp():
|
||||
async def forget_send_otp():
|
||||
"""
|
||||
POST /forget/otp
|
||||
- Verify the image captcha stored at captcha:{email} (case-insensitive).
|
||||
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
email = req.get("email") or ""
|
||||
captcha = (req.get("captcha") or "").strip()
|
||||
|
||||
@ -935,12 +936,12 @@ def forget_send_otp():
|
||||
|
||||
|
||||
@manager.route("/forget", methods=["POST"]) # noqa: F821
|
||||
def forget():
|
||||
async def forget():
|
||||
"""
|
||||
POST: Verify email + OTP and reset password, then log the user in.
|
||||
Request JSON: { email, otp, new_password, confirm_new_password }
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
email = req.get("email") or ""
|
||||
otp = (req.get("otp") or "").strip()
|
||||
new_pwd = req.get("new_password")
|
||||
|
||||
@ -25,7 +25,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import UserMixin
|
||||
from quart_auth import AuthUser
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
@ -595,7 +595,7 @@ def fill_db_model_object(model_object, human_model_dict):
|
||||
return model_object
|
||||
|
||||
|
||||
class User(DataBaseModel, UserMixin):
|
||||
class User(DataBaseModel, AuthUser):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
access_token = CharField(max_length=255, null=True, index=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
|
||||
|
||||
@ -24,7 +24,6 @@ from api.db import InputType
|
||||
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
@ -68,6 +67,7 @@ class ConnectorService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
|
||||
from api.db.services.file_service import FileService
|
||||
e, conn = cls.get_by_id(connector_id)
|
||||
if not e:
|
||||
return None
|
||||
@ -191,6 +191,7 @@ class SyncLogsService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
|
||||
from api.db.services.file_service import FileService
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
|
||||
@ -41,6 +41,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common import settings
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType
|
||||
@ -34,6 +33,7 @@ from api.db.services.task_service import TaskService
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
|
||||
from rag.llm.cv_model import GptV4
|
||||
from common import settings
|
||||
from api.apps import current_user
|
||||
|
||||
|
||||
class FileService(CommonService):
|
||||
|
||||
@ -24,9 +24,10 @@ from common.time_utils import current_timestamp, datetime_format
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import StatusEnum, RetCode
|
||||
from common.constants import StatusEnum
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.utils.api_utils import get_parser_config
|
||||
from api.utils.api_utils import get_parser_config, get_data_error_result
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
|
||||
"""Service class for managing knowledge base operations.
|
||||
@ -391,12 +392,12 @@ class KnowledgebaseService(CommonService):
|
||||
"""
|
||||
# Validate name
|
||||
if not isinstance(name, str):
|
||||
return {"code": RetCode.DATA_ERROR, "message": "Dataset name must be string."}
|
||||
return False, get_data_error_result(message="Dataset name must be string.")
|
||||
dataset_name = name.strip()
|
||||
if len(dataset_name) == 0:
|
||||
return {"code": RetCode.DATA_ERROR, "message": "Dataset name can't be empty."}
|
||||
if dataset_name == "":
|
||||
return False, get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return {"code": RetCode.DATA_ERROR, "message": f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}"}
|
||||
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
|
||||
# Deduplicate name within tenant
|
||||
dataset_name = duplicate_name(
|
||||
@ -409,7 +410,7 @@ class KnowledgebaseService(CommonService):
|
||||
# Verify tenant exists
|
||||
ok, _t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
return {"code": RetCode.DATA_ERROR, "message": "Tenant does not exist."}
|
||||
return False, get_data_error_result(message="Tenant not found.")
|
||||
|
||||
# Build payload
|
||||
kb_id = get_uuid()
|
||||
@ -425,7 +426,7 @@ class KnowledgebaseService(CommonService):
|
||||
# Update parser_config (always override with validated default/merged config)
|
||||
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
|
||||
|
||||
return payload
|
||||
return True, payload
|
||||
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -31,7 +31,6 @@ import traceback
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from werkzeug.serving import run_simple
|
||||
from api.apps import app, smtp_mail_server
|
||||
from api.db.runtime_config import RuntimeConfig
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -153,14 +152,7 @@ if __name__ == '__main__':
|
||||
# start http server
|
||||
try:
|
||||
logging.info("RAGFlow HTTP server start...")
|
||||
run_simple(
|
||||
hostname=settings.HOST_IP,
|
||||
port=settings.HOST_PORT,
|
||||
application=app,
|
||||
threaded=True,
|
||||
use_reloader=RuntimeConfig.DEBUG,
|
||||
use_debugger=RuntimeConfig.DEBUG,
|
||||
)
|
||||
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
stop_event.set()
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -24,14 +25,12 @@ from functools import wraps
|
||||
|
||||
import requests
|
||||
import trio
|
||||
from flask import (
|
||||
from quart import (
|
||||
Response,
|
||||
jsonify,
|
||||
request
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
|
||||
from peewee import OperationalError
|
||||
|
||||
from common.constants import ActiveEnum
|
||||
@ -46,6 +45,12 @@ from common import settings
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
|
||||
async def request_json():
|
||||
try:
|
||||
return await request.json
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def serialize_for_json(obj):
|
||||
"""
|
||||
Recursively serialize objects to make them JSON serializable.
|
||||
@ -105,31 +110,37 @@ def server_error_response(e):
|
||||
|
||||
|
||||
def validate_request(*args, **kwargs):
|
||||
def process_args(input_arguments):
|
||||
no_arguments = []
|
||||
error_arguments = []
|
||||
for arg in args:
|
||||
if arg not in input_arguments:
|
||||
no_arguments.append(arg)
|
||||
for k, v in kwargs.items():
|
||||
config_value = input_arguments.get(k, None)
|
||||
if config_value is None:
|
||||
no_arguments.append(k)
|
||||
elif isinstance(v, (tuple, list)):
|
||||
if config_value not in v:
|
||||
error_arguments.append((k, set(v)))
|
||||
elif config_value != v:
|
||||
error_arguments.append((k, v))
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return error_string
|
||||
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*_args, **_kwargs):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
no_arguments = []
|
||||
error_arguments = []
|
||||
for arg in args:
|
||||
if arg not in input_arguments:
|
||||
no_arguments.append(arg)
|
||||
for k, v in kwargs.items():
|
||||
config_value = input_arguments.get(k, None)
|
||||
if config_value is None:
|
||||
no_arguments.append(k)
|
||||
elif isinstance(v, (tuple, list)):
|
||||
if config_value not in v:
|
||||
error_arguments.append((k, set(v)))
|
||||
elif config_value != v:
|
||||
error_arguments.append((k, v))
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string)
|
||||
async def decorated_function(*_args, **_kwargs):
|
||||
errs = process_args(await request.json or (await request.form).to_dict())
|
||||
if errs:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*_args, **_kwargs)
|
||||
return func(*_args, **_kwargs)
|
||||
|
||||
return decorated_function
|
||||
@ -138,30 +149,34 @@ def validate_request(*args, **kwargs):
|
||||
|
||||
|
||||
def not_allowed_parameters(*params):
|
||||
def decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
def decorator(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
input_arguments = await request.json or (await request.form).to_dict()
|
||||
for param in params:
|
||||
if param in input_arguments:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
def active_required(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
from api.db.services import UserService
|
||||
from api.apps import current_user
|
||||
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return get_json_result(code=RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||
return f(*args, **kwargs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -173,12 +188,15 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non
|
||||
|
||||
def apikey_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
token = flask_request.headers.get("Authorization").split()[1]
|
||||
async def decorated_function(*args, **kwargs):
|
||||
token = request.headers.get("Authorization").split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
@ -199,23 +217,38 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
|
||||
|
||||
|
||||
def token_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
def get_tenant_id(**kwargs):
|
||||
if os.environ.get("DISABLE_SDK"):
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_str = flask_request.headers.get("Authorization")
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_str = request.headers.get("Authorization")
|
||||
if not authorization_str:
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_list = authorization_str.split()
|
||||
if len(authorization_list) < 2:
|
||||
return get_json_result(data=False, message="Please check your authorization format.")
|
||||
return False, get_json_result(data=False, message="Please check your authorization format.")
|
||||
token = authorization_list[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
return True, kwargs
|
||||
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def adecorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return adecorated_function
|
||||
return decorated_function
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ import base64
|
||||
import click
|
||||
import re
|
||||
|
||||
from flask import Flask
|
||||
from quart import Quart
|
||||
from werkzeug.security import generate_password_hash
|
||||
|
||||
from api.db.services import UserService
|
||||
@ -73,6 +73,7 @@ def reset_email(email, new_email, email_confirm):
|
||||
UserService.update_user(user[0].id,user_dict)
|
||||
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
|
||||
|
||||
def register_commands(app: Flask):
|
||||
|
||||
def register_commands(app: Quart):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
|
||||
@ -17,7 +17,7 @@ from collections import Counter
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Request
|
||||
from quart import Request
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@ -32,7 +32,7 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
|
||||
|
||||
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
|
||||
async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""
|
||||
Validates and parses JSON requests through a multi-stage validation pipeline.
|
||||
|
||||
@ -81,7 +81,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
|
||||
from the final output after validation
|
||||
"""
|
||||
try:
|
||||
payload = request.get_json() or {}
|
||||
payload = await request.get_json() or {}
|
||||
except UnsupportedMediaType:
|
||||
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
|
||||
except BadRequest:
|
||||
|
||||
@ -23,7 +23,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from api.apps import smtp_mail_server
|
||||
from flask_mail import Message
|
||||
from flask import render_template_string
|
||||
from quart import render_template_string
|
||||
from api.utils.email_templates import EMAIL_TEMPLATES
|
||||
from selenium import webdriver
|
||||
from selenium.common.exceptions import TimeoutException
|
||||
|
||||
@ -21,7 +21,7 @@ from typing import Any, Callable, Coroutine, Optional, Type, Union
|
||||
import asyncio
|
||||
import trio
|
||||
from functools import wraps
|
||||
from flask import make_response, jsonify
|
||||
from quart import make_response, jsonify
|
||||
from common.constants import RetCode
|
||||
|
||||
TimeoutException = Union[Type[BaseException], BaseException]
|
||||
@ -103,7 +103,7 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
|
||||
return decorator
|
||||
|
||||
|
||||
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
async def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
result_dict = {"code": code, "message": message, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
@ -111,7 +111,7 @@ def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
response = make_response(jsonify(response_dict))
|
||||
response = await make_response(jsonify(response_dict))
|
||||
if auth:
|
||||
response.headers["Authorization"] = auth
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
|
||||
@ -61,7 +61,9 @@ class DoclingParser(RAGFlowPdfParser):
|
||||
self.page_images: list[Image.Image] = []
|
||||
self.page_from = 0
|
||||
self.page_to = 10_000
|
||||
|
||||
self.outlines = []
|
||||
|
||||
|
||||
def check_installation(self) -> bool:
|
||||
if DocumentConverter is None:
|
||||
self.logger.warning("[Docling] 'docling' is not importable, please: pip install docling")
|
||||
|
||||
@ -59,6 +59,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
self.mineru_api = mineru_api.rstrip("/")
|
||||
self.mineru_server_url = mineru_server_url.rstrip("/")
|
||||
self.using_api = False
|
||||
self.outlines = []
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
|
||||
@ -337,12 +338,54 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
return None, None
|
||||
return
|
||||
|
||||
if not getattr(self, "page_images", None):
|
||||
self.logger.warning("[MinerU] crop called without page images; skipping image generation.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
page_count = len(self.page_images)
|
||||
|
||||
filtered_poss = []
|
||||
for pns, left, right, top, bottom in poss:
|
||||
if not pns:
|
||||
self.logger.warning("[MinerU] Empty page index list in crop; skipping this position.")
|
||||
continue
|
||||
valid_pns = [p for p in pns if 0 <= p < page_count]
|
||||
if not valid_pns:
|
||||
self.logger.warning(f"[MinerU] All page indices {pns} out of range for {page_count} pages; skipping.")
|
||||
continue
|
||||
filtered_poss.append((valid_pns, left, right, top, bottom))
|
||||
|
||||
poss = filtered_poss
|
||||
if not poss:
|
||||
self.logger.warning("[MinerU] No valid positions after filtering; skip cropping.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
|
||||
GAP = 6
|
||||
pos = poss[0]
|
||||
poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
|
||||
first_page_idx = pos[0][0]
|
||||
poss.insert(0, ([first_page_idx], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
|
||||
pos = poss[-1]
|
||||
poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1], pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1], pos[4] + 120)))
|
||||
last_page_idx = pos[0][-1]
|
||||
if not (0 <= last_page_idx < page_count):
|
||||
self.logger.warning(f"[MinerU] Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
last_page_height = self.page_images[last_page_idx].size[1]
|
||||
poss.append(
|
||||
(
|
||||
[last_page_idx],
|
||||
pos[1],
|
||||
pos[2],
|
||||
min(last_page_height, pos[4] + GAP),
|
||||
min(last_page_height, pos[4] + 120),
|
||||
)
|
||||
)
|
||||
|
||||
positions = []
|
||||
for ii, (pns, left, right, top, bottom) in enumerate(poss):
|
||||
@ -352,7 +395,14 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
bottom = top + 2
|
||||
|
||||
for pn in pns[1:]:
|
||||
bottom += self.page_images[pn - 1].size[1]
|
||||
if 0 <= pn - 1 < page_count:
|
||||
bottom += self.page_images[pn - 1].size[1]
|
||||
else:
|
||||
self.logger.warning(f"[MinerU] Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.")
|
||||
|
||||
if not (0 <= pns[0] < page_count):
|
||||
self.logger.warning(f"[MinerU] Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.")
|
||||
continue
|
||||
|
||||
img0 = self.page_images[pns[0]]
|
||||
x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1]))
|
||||
@ -363,6 +413,9 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
|
||||
bottom -= img0.size[1]
|
||||
for pn in pns[1:]:
|
||||
if not (0 <= pn < page_count):
|
||||
self.logger.warning(f"[MinerU] Page index {pn} out of range for {page_count} pages during crop; skipping this page.")
|
||||
continue
|
||||
page = self.page_images[pn]
|
||||
x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1]))
|
||||
cimgp = page.crop((x0, y0, x1, y1))
|
||||
|
||||
@ -1252,24 +1252,77 @@ class RAGFlowPdfParser:
|
||||
return None, None
|
||||
return
|
||||
|
||||
if not getattr(self, "page_images", None):
|
||||
logging.warning("crop called without page images; skipping image generation.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
page_count = len(self.page_images)
|
||||
|
||||
filtered_poss = []
|
||||
for pns, left, right, top, bottom in poss:
|
||||
if not pns:
|
||||
logging.warning("Empty page index list in crop; skipping this position.")
|
||||
continue
|
||||
valid_pns = [p for p in pns if 0 <= p < page_count]
|
||||
if not valid_pns:
|
||||
logging.warning(f"All page indices {pns} out of range for {page_count} pages; skipping.")
|
||||
continue
|
||||
filtered_poss.append((valid_pns, left, right, top, bottom))
|
||||
|
||||
poss = filtered_poss
|
||||
if not poss:
|
||||
logging.warning("No valid positions after filtering; skip cropping.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
|
||||
GAP = 6
|
||||
pos = poss[0]
|
||||
poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
|
||||
first_page_idx = pos[0][0]
|
||||
poss.insert(0, ([first_page_idx], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
|
||||
pos = poss[-1]
|
||||
poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120)))
|
||||
last_page_idx = pos[0][-1]
|
||||
if not (0 <= last_page_idx < page_count):
|
||||
logging.warning(f"Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
last_page_height = self.page_images[last_page_idx].size[1] / ZM
|
||||
poss.append(
|
||||
(
|
||||
[last_page_idx],
|
||||
pos[1],
|
||||
pos[2],
|
||||
min(last_page_height, pos[4] + GAP),
|
||||
min(last_page_height, pos[4] + 120),
|
||||
)
|
||||
)
|
||||
|
||||
positions = []
|
||||
for ii, (pns, left, right, top, bottom) in enumerate(poss):
|
||||
right = left + max_width
|
||||
bottom *= ZM
|
||||
for pn in pns[1:]:
|
||||
bottom += self.page_images[pn - 1].size[1]
|
||||
if 0 <= pn - 1 < page_count:
|
||||
bottom += self.page_images[pn - 1].size[1]
|
||||
else:
|
||||
logging.warning(f"Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.")
|
||||
|
||||
if not (0 <= pns[0] < page_count):
|
||||
logging.warning(f"Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.")
|
||||
continue
|
||||
|
||||
imgs.append(self.page_images[pns[0]].crop((left * ZM, top * ZM, right * ZM, min(bottom, self.page_images[pns[0]].size[1]))))
|
||||
if 0 < ii < len(poss) - 1:
|
||||
positions.append((pns[0] + self.page_from, left, right, top, min(bottom, self.page_images[pns[0]].size[1]) / ZM))
|
||||
bottom -= self.page_images[pns[0]].size[1]
|
||||
for pn in pns[1:]:
|
||||
if not (0 <= pn < page_count):
|
||||
logging.warning(f"Page index {pn} out of range for {page_count} pages during crop; skipping this page.")
|
||||
continue
|
||||
imgs.append(self.page_images[pn].crop((left * ZM, 0, right * ZM, min(bottom, self.page_images[pn].size[1]))))
|
||||
if 0 < ii < len(poss) - 1:
|
||||
positions.append((pn + self.page_from, left, right, 0, min(bottom, self.page_images[pn].size[1]) / ZM))
|
||||
|
||||
@ -47,6 +47,7 @@ class TencentCloudAPIClient:
|
||||
self.secret_id = secret_id
|
||||
self.secret_key = secret_key
|
||||
self.region = region
|
||||
self.outlines = []
|
||||
|
||||
# Create credentials
|
||||
self.cred = credential.Credential(secret_id, secret_key)
|
||||
|
||||
@ -86,6 +86,9 @@ dependencies = [
|
||||
"python-pptx>=1.0.2,<2.0.0",
|
||||
"pywencai==0.12.2",
|
||||
"qianfan==0.4.6",
|
||||
"quart-auth==0.11.0",
|
||||
"quart-cors==0.8.0",
|
||||
"Quart==0.20.0",
|
||||
"ranx==0.3.20",
|
||||
"readability-lxml==0.8.1",
|
||||
"valkey==6.0.2",
|
||||
|
||||
@ -216,6 +216,30 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _normalize_section(section):
|
||||
# pad section to length 3: (txt, sec_id, poss)
|
||||
if len(section) == 1:
|
||||
section = (section[0], "", [])
|
||||
elif len(section) == 2:
|
||||
section = (section[0], "", section[1])
|
||||
elif len(section) != 3:
|
||||
raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})")
|
||||
|
||||
txt, sec_id, poss = section
|
||||
if isinstance(poss, str):
|
||||
poss = pdf_parser.extract_positions(poss)
|
||||
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
|
||||
pn = first[0]
|
||||
|
||||
if isinstance(pn, list):
|
||||
pn = pn[0] # [pn] -> pn
|
||||
poss[0] = (pn, *first[1:])
|
||||
|
||||
return (txt, sec_id, poss)
|
||||
|
||||
|
||||
sections = [_normalize_section(sec) for sec in sections]
|
||||
|
||||
if not sections and not tbls:
|
||||
return []
|
||||
|
||||
|
||||
@ -70,6 +70,7 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese"
|
||||
callback=callback,
|
||||
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
|
||||
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
|
||||
server_url=os.environ.get("MINERU_SERVER_URL", ""),
|
||||
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
|
||||
)
|
||||
return sections, tables, pdf_parser
|
||||
|
||||
@ -52,7 +52,7 @@ def chunks_format(reference):
|
||||
"similarity": chunk.get("similarity"),
|
||||
"vector_similarity": chunk.get("vector_similarity"),
|
||||
"term_similarity": chunk.get("term_similarity"),
|
||||
"doc_type": chunk.get("doc_type_kwd"),
|
||||
"doc_type": get_value(chunk, "doc_type_kwd", "doc_type"),
|
||||
}
|
||||
for chunk in reference.get("chunks", [])
|
||||
]
|
||||
|
||||
@ -101,10 +101,10 @@ def test_invalid_name_dataset(get_auth):
|
||||
# create dataset
|
||||
# with pytest.raises(Exception) as e:
|
||||
res = create_dataset(get_auth, 0)
|
||||
assert res['code'] == 100
|
||||
assert res['code'] != 0
|
||||
|
||||
res = create_dataset(get_auth, "")
|
||||
assert res['code'] == 102
|
||||
assert res['code'] != 0
|
||||
|
||||
long_string = ""
|
||||
|
||||
@ -112,7 +112,7 @@ def test_invalid_name_dataset(get_auth):
|
||||
long_string += random.choice(string.ascii_letters + string.digits)
|
||||
|
||||
res = create_dataset(get_auth, long_string)
|
||||
assert res['code'] == 102
|
||||
assert res['code'] != 0
|
||||
print(res)
|
||||
|
||||
|
||||
|
||||
@ -496,7 +496,7 @@ const DynamicForm = {
|
||||
<Textarea
|
||||
{...finalFieldProps}
|
||||
placeholder={field.placeholder}
|
||||
className="resize-none"
|
||||
// className="resize-none"
|
||||
/>
|
||||
);
|
||||
}}
|
||||
@ -544,7 +544,7 @@ const DynamicForm = {
|
||||
render={({ field: formField }) => (
|
||||
<FormItem
|
||||
className={cn('flex items-center w-full', {
|
||||
'flex-row items-start space-x-3 space-y-0':
|
||||
'flex-row items-center space-x-3 space-y-0':
|
||||
!field.horizontal,
|
||||
})}
|
||||
>
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
'use client';
|
||||
|
||||
import { EyeIcon, EyeOffIcon } from 'lucide-react';
|
||||
import React, { useId, useState } from 'react';
|
||||
import { Input, InputProps } from '../ui/input';
|
||||
|
||||
@ -33,11 +32,11 @@ export default React.forwardRef<HTMLInputElement, InputProps>(
|
||||
aria-pressed={isVisible}
|
||||
aria-controls="password"
|
||||
>
|
||||
{isVisible ? (
|
||||
{/* {isVisible ? (
|
||||
<EyeOffIcon size={16} aria-hidden="true" />
|
||||
) : (
|
||||
<EyeIcon size={16} aria-hidden="true" />
|
||||
)}
|
||||
)} */}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -140,7 +140,7 @@ export const SelectWithSearch = forwardRef<
|
||||
ref={ref}
|
||||
disabled={disabled}
|
||||
className={cn(
|
||||
'!bg-bg-input hover:bg-background border-input w-full justify-between px-3 font-normal outline-offset-0 outline-none focus-visible:outline-[3px] [&_svg]:pointer-events-auto',
|
||||
'!bg-bg-input hover:bg-background border-border-button w-full justify-between px-3 font-normal outline-offset-0 outline-none focus-visible:outline-[3px] [&_svg]:pointer-events-auto',
|
||||
triggerClassName,
|
||||
)}
|
||||
>
|
||||
@ -166,7 +166,7 @@ export const SelectWithSearch = forwardRef<
|
||||
)}
|
||||
<ChevronDownIcon
|
||||
size={16}
|
||||
className="text-muted-foreground/80 shrink-0 ml-2"
|
||||
className="text-text-disabled shrink-0 ml-2"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@ -41,7 +41,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
||||
type={isPasswordInput && showPassword ? 'text' : type}
|
||||
className={cn(
|
||||
'peer/input',
|
||||
'flex h-8 w-full rounded-md border-0.5 border-input bg-bg-input px-3 py-2 outline-none text-sm text-text-primary',
|
||||
'flex h-8 w-full rounded-md border-0.5 border-border-button bg-bg-input px-3 py-2 outline-none text-sm text-text-primary',
|
||||
'file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-foreground placeholder:text-text-disabled',
|
||||
'focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-accent-primary',
|
||||
'disabled:cursor-not-allowed disabled:opacity-50 transition-colors',
|
||||
|
||||
@ -54,7 +54,7 @@ const Textarea = forwardRef<HTMLTextAreaElement, TextareaProps>(
|
||||
return (
|
||||
<textarea
|
||||
className={cn(
|
||||
'flex min-h-[80px] w-full bg-bg-input rounded-md border border-input px-3 py-2 text-base ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm overflow-hidden',
|
||||
'flex min-h-[80px] w-full bg-bg-input rounded-md border border-input px-3 py-2 text-base ring-offset-background placeholder:text-text-disabled focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-accent-primary focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm overflow-hidden',
|
||||
className,
|
||||
)}
|
||||
rows={autoSize?.minRows ?? props.rows ?? undefined}
|
||||
|
||||
@ -4,7 +4,7 @@ import * as TooltipPrimitive from '@radix-ui/react-tooltip';
|
||||
import * as React from 'react';
|
||||
|
||||
import { cn } from '@/lib/utils';
|
||||
import { Info } from 'lucide-react';
|
||||
import { CircleQuestionMark } from 'lucide-react';
|
||||
|
||||
const TooltipProvider = TooltipPrimitive.Provider;
|
||||
|
||||
@ -39,7 +39,7 @@ export const FormTooltip = ({ tooltip }: { tooltip: React.ReactNode }) => {
|
||||
e.preventDefault(); // Prevent clicking the tooltip from triggering form save
|
||||
}}
|
||||
>
|
||||
<Info className="size-3 ml-2" />
|
||||
<CircleQuestionMark className="size-3 ml-2" />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{tooltip}</TooltipContent>
|
||||
</Tooltip>
|
||||
|
||||
@ -61,9 +61,10 @@ function buildLlmOptionsWithIcon(x: IThirdOAIModel) {
|
||||
<div className="flex items-center justify-center gap-6">
|
||||
<LlmIcon
|
||||
name={getLLMIconName(x.fid, x.llm_name)}
|
||||
width={26}
|
||||
height={26}
|
||||
width={24}
|
||||
height={24}
|
||||
size={'small'}
|
||||
imgClass="size-6"
|
||||
/>
|
||||
<span>{getRealModelName(x.llm_name)}</span>
|
||||
</div>
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
export default {
|
||||
translation: {
|
||||
common: {
|
||||
confirm: 'Confirm',
|
||||
back: 'Back',
|
||||
noResults: 'No results.',
|
||||
selectPlaceholder: 'select value',
|
||||
@ -694,6 +695,7 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
|
||||
tocEnhanceTip: ` During the parsing of the document, table of contents information was generated (see the 'Enable Table of Contents Extraction' option in the General method). This allows the large model to return table of contents items relevant to the user's query, thereby using these items to retrieve related chunks and apply weighting to these chunks during the sorting process. This approach is derived from mimicking the behavioral logic of how humans search for knowledge in books.`,
|
||||
},
|
||||
setting: {
|
||||
edit: 'Edit',
|
||||
cropTip:
|
||||
'Drag the selection area to choose the cropping position of the image, and scroll to zoom in/out',
|
||||
cropImage: 'Crop image',
|
||||
@ -1852,6 +1854,19 @@ Important structured information may include: names, dates, locations, events, k
|
||||
asc: 'Ascending',
|
||||
desc: 'Descending',
|
||||
},
|
||||
variableAssignerLogicalOperatorOptions: {
|
||||
overwrite: 'Overwrite',
|
||||
clear: 'Clear',
|
||||
set: 'Set',
|
||||
'+=': 'Add',
|
||||
'-=': 'Subtract',
|
||||
'*=': 'Multiply',
|
||||
'/=': 'Divide',
|
||||
append: 'Append',
|
||||
extend: 'Extend',
|
||||
removeFirst: 'Remove first',
|
||||
removeLast: 'Remove last',
|
||||
},
|
||||
},
|
||||
llmTools: {
|
||||
bad_calculator: {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
export default {
|
||||
translation: {
|
||||
common: {
|
||||
confirm: '确定',
|
||||
back: '返回',
|
||||
noResults: '无结果。',
|
||||
selectPlaceholder: '请选择',
|
||||
@ -684,6 +685,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
tocEnhanceTip: `解析文档时生成了目录信息(见General方法的‘启用目录抽取’),让大模型返回和用户问题相关的目录项,从而利用目录项拿到相关chunk,对这些chunk在排序中进行加权。这种方法来源于模仿人类查询书本中知识的行为逻辑`,
|
||||
},
|
||||
setting: {
|
||||
edit: '编辑',
|
||||
cropTip: '拖动选区可以选择要图片的裁剪位置,滚动可以放大/缩小选区',
|
||||
cropImage: '剪裁图片',
|
||||
selectModelPlaceholder: '请选择模型',
|
||||
@ -1713,6 +1715,19 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`,
|
||||
asc: '升序',
|
||||
desc: '降序',
|
||||
},
|
||||
variableAssignerLogicalOperatorOptions: {
|
||||
overwrite: '覆盖',
|
||||
clear: '清除',
|
||||
set: '设置',
|
||||
add: '加',
|
||||
subtract: '减',
|
||||
multiply: '乘',
|
||||
divide: '除',
|
||||
append: '追加',
|
||||
extend: '扩展',
|
||||
removeFirst: '移除第一个',
|
||||
removeLast: '移除最后一个',
|
||||
},
|
||||
},
|
||||
footer: {
|
||||
profile: 'All rights reserved @ React',
|
||||
|
||||
@ -81,7 +81,6 @@ export function AccordionOperators({
|
||||
Operator.DataOperations,
|
||||
Operator.VariableAssigner,
|
||||
Operator.ListOperations,
|
||||
Operator.VariableAssigner,
|
||||
Operator.VariableAggregator,
|
||||
]}
|
||||
isCustomDropdown={isCustomDropdown}
|
||||
|
||||
@ -1,11 +1,31 @@
|
||||
import { IRagNode } from '@/interfaces/database/agent';
|
||||
import { NodeCollapsible } from '@/components/collapse';
|
||||
import { BaseNode } from '@/interfaces/database/agent';
|
||||
import { NodeProps } from '@xyflow/react';
|
||||
import { RagNode } from '.';
|
||||
import { VariableAssignerFormSchemaType } from '../../form/variable-assigner-form';
|
||||
import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query';
|
||||
import { LabelCard } from './card';
|
||||
|
||||
export function VariableAssignerNode({
|
||||
...props
|
||||
}: NodeProps<BaseNode<VariableAssignerFormSchemaType>>) {
|
||||
const { data } = props;
|
||||
const { getLabel } = useGetVariableLabelOrTypeByValue();
|
||||
|
||||
export function VariableAssignerNode({ ...props }: NodeProps<IRagNode>) {
|
||||
return (
|
||||
<RagNode {...props}>
|
||||
<section>select</section>
|
||||
<NodeCollapsible items={data.form?.variables}>
|
||||
{(x, idx) => (
|
||||
<section key={idx} className="space-y-1">
|
||||
<LabelCard key={idx} className="flex justify-between gap-2">
|
||||
<span className="flex truncate min-w-0">
|
||||
{getLabel(x.variable)}
|
||||
</span>
|
||||
<span className="border px-1 rounded-sm">{x.operator}</span>
|
||||
</LabelCard>
|
||||
</section>
|
||||
)}
|
||||
</NodeCollapsible>
|
||||
</RagNode>
|
||||
);
|
||||
}
|
||||
|
||||
@ -463,21 +463,14 @@ export const initialAgentValues = {
|
||||
tools: [],
|
||||
mcp: [],
|
||||
cite: true,
|
||||
showStructuredOutput: false,
|
||||
[AgentStructuredOutputField]: {},
|
||||
outputs: {
|
||||
// structured_output: {
|
||||
// topic: {
|
||||
// type: 'string',
|
||||
// description:
|
||||
// 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.',
|
||||
// enum: ['general', 'news'],
|
||||
// default: 'general',
|
||||
// },
|
||||
// },
|
||||
content: {
|
||||
type: 'string',
|
||||
value: '',
|
||||
},
|
||||
[AgentStructuredOutputField]: {},
|
||||
// [AgentStructuredOutputField]: {},
|
||||
},
|
||||
};
|
||||
|
||||
@ -858,6 +851,13 @@ export enum VariableAssignerLogicalNumberOperator {
|
||||
Divide = '/=',
|
||||
}
|
||||
|
||||
export const VariableAssignerLogicalNumberOperatorLabelMap = {
|
||||
[VariableAssignerLogicalNumberOperator.Add]: 'add',
|
||||
[VariableAssignerLogicalNumberOperator.Subtract]: 'subtract',
|
||||
[VariableAssignerLogicalNumberOperator.Multiply]: 'multiply',
|
||||
[VariableAssignerLogicalNumberOperator.Divide]: 'divide',
|
||||
};
|
||||
|
||||
export enum VariableAssignerLogicalArrayOperator {
|
||||
Overwrite = VariableAssignerLogicalOperator.Overwrite,
|
||||
Clear = VariableAssignerLogicalOperator.Clear,
|
||||
|
||||
@ -6,6 +6,7 @@ import {
|
||||
import { LlmSettingSchema } from '@/components/llm-setting-items/next';
|
||||
import { MessageHistoryWindowSizeFormField } from '@/components/message-history-window-size-item';
|
||||
import { SelectWithSearch } from '@/components/originui/select-with-search';
|
||||
import { RAGFlowFormItem } from '@/components/ragflow-form';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
Form,
|
||||
@ -15,6 +16,7 @@ import {
|
||||
FormLabel,
|
||||
} from '@/components/ui/form';
|
||||
import { Input, NumberInput } from '@/components/ui/input';
|
||||
import { Label } from '@/components/ui/label';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import { LlmModelType } from '@/constants/knowledge';
|
||||
@ -26,9 +28,9 @@ import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
AgentExceptionMethod,
|
||||
AgentStructuredOutputField,
|
||||
NodeHandleId,
|
||||
VariableType,
|
||||
initialAgentValues,
|
||||
} from '../../constant';
|
||||
import { INextOperatorForm } from '../../interface';
|
||||
import useGraphStore from '../../store';
|
||||
@ -71,18 +73,20 @@ const FormSchema = z.object({
|
||||
exception_default_value: z.string().optional(),
|
||||
...LargeModelFilterFormSchema,
|
||||
cite: z.boolean().optional(),
|
||||
showStructuredOutput: z.boolean().optional(),
|
||||
[AgentStructuredOutputField]: z.record(z.any()),
|
||||
});
|
||||
|
||||
export type AgentFormSchemaType = z.infer<typeof FormSchema>;
|
||||
|
||||
const outputList = buildOutputList(initialAgentValues.outputs);
|
||||
|
||||
function AgentForm({ node }: INextOperatorForm) {
|
||||
const { t } = useTranslation();
|
||||
const { edges, deleteEdgesBySourceAndSourceHandle } = useGraphStore(
|
||||
(state) => state,
|
||||
);
|
||||
|
||||
const outputList = buildOutputList(node?.data.form.outputs);
|
||||
|
||||
const defaultValues = useValues(node);
|
||||
|
||||
const { extraOptions } = useBuildPromptExtraPromptOptions(edges, node?.id);
|
||||
@ -112,13 +116,18 @@ function AgentForm({ node }: INextOperatorForm) {
|
||||
name: 'exception_method',
|
||||
});
|
||||
|
||||
const showStructuredOutput = useWatch({
|
||||
control: form.control,
|
||||
name: 'showStructuredOutput',
|
||||
});
|
||||
|
||||
const {
|
||||
initialStructuredOutput,
|
||||
showStructuredOutputDialog,
|
||||
structuredOutputDialogVisible,
|
||||
hideStructuredOutputDialog,
|
||||
handleStructuredOutputDialogOk,
|
||||
} = useShowStructuredOutputDialog(node?.id);
|
||||
} = useShowStructuredOutputDialog(form);
|
||||
|
||||
useEffect(() => {
|
||||
if (exceptionMethod !== AgentExceptionMethod.Goto) {
|
||||
@ -275,18 +284,42 @@ function AgentForm({ node }: INextOperatorForm) {
|
||||
)}
|
||||
</section>
|
||||
</Collapse>
|
||||
<Output list={outputList}></Output>
|
||||
<section className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
{t('flow.structuredOutput.structuredOutput')}
|
||||
<Button variant={'outline'} onClick={showStructuredOutputDialog}>
|
||||
{t('flow.structuredOutput.configuration')}
|
||||
</Button>
|
||||
</div>
|
||||
<StructuredOutputPanel
|
||||
value={initialStructuredOutput}
|
||||
></StructuredOutputPanel>
|
||||
</section>
|
||||
<RAGFlowFormItem name={AgentStructuredOutputField} className="hidden">
|
||||
<Input></Input>
|
||||
</RAGFlowFormItem>
|
||||
<Output list={outputList}>
|
||||
<RAGFlowFormItem name="showStructuredOutput">
|
||||
{(field) => (
|
||||
<div className="flex items-center space-x-2">
|
||||
<Label htmlFor="airplane-mode">
|
||||
{t('flow.structuredOutput.structuredOutput')}
|
||||
</Label>
|
||||
<Switch
|
||||
id="airplane-mode"
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</RAGFlowFormItem>
|
||||
</Output>
|
||||
{showStructuredOutput && (
|
||||
<section className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
{t('flow.structuredOutput.structuredOutput')}
|
||||
<Button
|
||||
variant={'outline'}
|
||||
onClick={showStructuredOutputDialog}
|
||||
>
|
||||
{t('flow.structuredOutput.configuration')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<StructuredOutputPanel
|
||||
value={initialStructuredOutput}
|
||||
></StructuredOutputPanel>
|
||||
</section>
|
||||
)}
|
||||
</FormWrapper>
|
||||
</Form>
|
||||
{structuredOutputDialogVisible && (
|
||||
|
||||
@ -1,27 +1,25 @@
|
||||
import { JSONSchema } from '@/components/jsonjoy-builder';
|
||||
import { AgentStructuredOutputField } from '@/constants/agent';
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { useCallback } from 'react';
|
||||
import useGraphStore from '../../store';
|
||||
import { UseFormReturn } from 'react-hook-form';
|
||||
|
||||
export function useShowStructuredOutputDialog(nodeId?: string) {
|
||||
export function useShowStructuredOutputDialog(form: UseFormReturn<any>) {
|
||||
const {
|
||||
visible: structuredOutputDialogVisible,
|
||||
showModal: showStructuredOutputDialog,
|
||||
hideModal: hideStructuredOutputDialog,
|
||||
} = useSetModalState();
|
||||
const { updateNodeForm, getNode } = useGraphStore((state) => state);
|
||||
|
||||
const initialStructuredOutput = getNode(nodeId)?.data.form.outputs.structured;
|
||||
const initialStructuredOutput = form.getValues(AgentStructuredOutputField);
|
||||
|
||||
const handleStructuredOutputDialogOk = useCallback(
|
||||
(values: JSONSchema) => {
|
||||
// Sync data to canvas
|
||||
if (nodeId) {
|
||||
updateNodeForm(nodeId, values, ['outputs', 'structured']);
|
||||
}
|
||||
form.setValue(AgentStructuredOutputField, values);
|
||||
hideStructuredOutputDialog();
|
||||
},
|
||||
[hideStructuredOutputDialog, nodeId, updateNodeForm],
|
||||
[form, hideStructuredOutputDialog],
|
||||
);
|
||||
|
||||
return {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { omit } from 'lodash';
|
||||
import { useEffect } from 'react';
|
||||
import { UseFormReturn, useWatch } from 'react-hook-form';
|
||||
import { PromptRole } from '../../constant';
|
||||
import { AgentStructuredOutputField, PromptRole } from '../../constant';
|
||||
import useGraphStore from '../../store';
|
||||
|
||||
export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
|
||||
@ -16,6 +17,20 @@ export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
|
||||
prompts: [{ role: PromptRole.User, content: values.prompts }],
|
||||
};
|
||||
|
||||
if (values.showStructuredOutput) {
|
||||
nextValues = {
|
||||
...nextValues,
|
||||
outputs: {
|
||||
...values.outputs,
|
||||
[AgentStructuredOutputField]: values[AgentStructuredOutputField],
|
||||
},
|
||||
};
|
||||
} else {
|
||||
nextValues = {
|
||||
...nextValues,
|
||||
outputs: omit(values.outputs, [AgentStructuredOutputField]),
|
||||
};
|
||||
}
|
||||
updateNodeForm(id, nextValues);
|
||||
}
|
||||
}, [form?.formState.isDirty, id, updateNodeForm, values]);
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { RAGFlowFormItem } from '@/components/ragflow-form';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { t } from 'i18next';
|
||||
import { PropsWithChildren } from 'react';
|
||||
import { z } from 'zod';
|
||||
|
||||
export type OutputType = {
|
||||
@ -11,7 +12,7 @@ export type OutputType = {
|
||||
type OutputProps = {
|
||||
list: Array<OutputType>;
|
||||
isFormRequired?: boolean;
|
||||
};
|
||||
} & PropsWithChildren;
|
||||
|
||||
export function transferOutputs(outputs: Record<string, any>) {
|
||||
return Object.entries(outputs).map(([key, value]) => ({
|
||||
@ -24,10 +25,16 @@ export const OutputSchema = {
|
||||
outputs: z.record(z.any()),
|
||||
};
|
||||
|
||||
export function Output({ list, isFormRequired = false }: OutputProps) {
|
||||
export function Output({
|
||||
list,
|
||||
isFormRequired = false,
|
||||
children,
|
||||
}: OutputProps) {
|
||||
return (
|
||||
<section className="space-y-2">
|
||||
<div className="text-sm">{t('flow.output')}</div>
|
||||
<div className="text-sm flex items-center justify-between">
|
||||
{t('flow.output')} <span>{children}</span>
|
||||
</div>
|
||||
<ul>
|
||||
{list.map((x, idx) => (
|
||||
<li
|
||||
|
||||
@ -4,6 +4,7 @@ import {
|
||||
HoverCardTrigger,
|
||||
} from '@/components/ui/hover-card';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { getStructuredDatatype } from '@/utils/canvas-util';
|
||||
import { get, isEmpty, isPlainObject } from 'lodash';
|
||||
import { ChevronRight } from 'lucide-react';
|
||||
import { PropsWithChildren, ReactNode, useCallback } from 'react';
|
||||
@ -62,22 +63,28 @@ export function StructuredOutputSecondaryMenu({
|
||||
value: option.value + `.${key}`,
|
||||
};
|
||||
|
||||
const dataType = get(value, 'type');
|
||||
const { dataType, compositeDataType } =
|
||||
getStructuredDatatype(value);
|
||||
|
||||
if (
|
||||
isEmpty(types) ||
|
||||
(!isEmpty(types) &&
|
||||
(types?.some((x) => x === dataType) ||
|
||||
(types?.some((x) => x === compositeDataType) ||
|
||||
hasSpecificTypeChild(value ?? {}, types)))
|
||||
) {
|
||||
return (
|
||||
<li key={key} className="pl-1">
|
||||
<div
|
||||
onClick={handleSubMenuClick(nextOption, dataType)}
|
||||
onClick={handleSubMenuClick(
|
||||
nextOption,
|
||||
compositeDataType,
|
||||
)}
|
||||
className="hover:bg-bg-card p-1 text-text-primary rounded-sm flex justify-between"
|
||||
>
|
||||
{key}
|
||||
<span className="text-text-secondary">{dataType}</span>
|
||||
<span className="text-text-secondary">
|
||||
{compositeDataType}
|
||||
</span>
|
||||
</div>
|
||||
{[JsonSchemaDataType.Object, JsonSchemaDataType.Array].some(
|
||||
(x) => x === dataType,
|
||||
@ -122,7 +129,7 @@ export function StructuredOutputSecondaryMenu({
|
||||
side="left"
|
||||
align="start"
|
||||
className={cn(
|
||||
'min-w-[140px] border border-border rounded-md shadow-lg p-0',
|
||||
'min-w-72 border border-border rounded-md shadow-lg p-0',
|
||||
)}
|
||||
>
|
||||
<section className="p-2">
|
||||
|
||||
@ -81,7 +81,8 @@ function MessageForm({ node }: INextOperatorForm) {
|
||||
)}
|
||||
{...field}
|
||||
onValueChange={field.onChange}
|
||||
placeholder={t('flow.messagePlaceholder')}
|
||||
placeholder={t('common.selectPlaceholder')}
|
||||
allowClear
|
||||
></RAGFlowSelect>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { RAGFlowNodeType } from '@/interfaces/database/flow';
|
||||
import { isEmpty } from 'lodash';
|
||||
import { useMemo } from 'react';
|
||||
import { ExportFileType, initialMessageValues } from '../../constant';
|
||||
import { initialMessageValues } from '../../constant';
|
||||
import { convertToObjectArray } from '../../utils';
|
||||
|
||||
export function useValues(node?: RAGFlowNodeType) {
|
||||
@ -15,7 +15,6 @@ export function useValues(node?: RAGFlowNodeType) {
|
||||
return {
|
||||
...formData,
|
||||
content: convertToObjectArray(formData.content),
|
||||
output_format: formData.output_format || ExportFileType.PDF,
|
||||
};
|
||||
}, [node]);
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import { FormContainer } from '@/components/form-container';
|
||||
import { SelectWithSearch } from '@/components/originui/select-with-search';
|
||||
import { BlockButton, Button } from '@/components/ui/button';
|
||||
import { Card, CardContent } from '@/components/ui/card';
|
||||
import {
|
||||
@ -16,16 +15,15 @@ import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-ope
|
||||
import { cn } from '@/lib/utils';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { t } from 'i18next';
|
||||
import { toLower } from 'lodash';
|
||||
import { X } from 'lucide-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useFieldArray, useForm, useFormContext } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
import { SwitchLogicOperatorOptions, VariableType } from '../../constant';
|
||||
import { useBuildQueryVariableOptions } from '../../hooks/use-get-begin-query';
|
||||
import { SwitchLogicOperatorOptions } from '../../constant';
|
||||
import { IOperatorForm } from '../../interface';
|
||||
import { FormWrapper } from '../components/form-wrapper';
|
||||
import { QueryVariable } from '../components/query-variable';
|
||||
import { useValues } from './use-values';
|
||||
import { useWatchFormChange } from './use-watch-change';
|
||||
|
||||
@ -47,19 +45,6 @@ function ConditionCards({
|
||||
}: ConditionCardsProps) {
|
||||
const form = useFormContext();
|
||||
|
||||
const nextOptions = useBuildQueryVariableOptions();
|
||||
|
||||
const finalOptions = useMemo(() => {
|
||||
return nextOptions.map((x) => {
|
||||
return {
|
||||
...x,
|
||||
options: x.options.filter(
|
||||
(y) => !toLower(y.type).includes(VariableType.Array),
|
||||
),
|
||||
};
|
||||
});
|
||||
}, [nextOptions]);
|
||||
|
||||
const switchOperatorOptions = useBuildSwitchOperatorOptions();
|
||||
|
||||
const name = `${parentName}.${ItemKey}`;
|
||||
@ -101,11 +86,11 @@ function ConditionCards({
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex-1 min-w-0">
|
||||
<FormControl>
|
||||
<SelectWithSearch
|
||||
<QueryVariable
|
||||
pureQuery
|
||||
{...field}
|
||||
options={finalOptions}
|
||||
triggerClassName="text-accent-primary bg-transparent border-none truncate"
|
||||
></SelectWithSearch>
|
||||
hideLabel
|
||||
></QueryVariable>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
|
||||
@ -1,26 +1,57 @@
|
||||
import { buildOptions } from '@/utils/form';
|
||||
import { camelCase } from 'lodash';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
JsonSchemaDataType,
|
||||
VariableAssignerLogicalArrayOperator,
|
||||
VariableAssignerLogicalNumberOperator,
|
||||
VariableAssignerLogicalNumberOperatorLabelMap,
|
||||
VariableAssignerLogicalOperator,
|
||||
} from '../../constant';
|
||||
|
||||
export function useBuildLogicalOptions() {
|
||||
const buildLogicalOptions = useCallback((type: string) => {
|
||||
if (
|
||||
type?.toLowerCase().startsWith(JsonSchemaDataType.Array.toLowerCase())
|
||||
) {
|
||||
return buildOptions(VariableAssignerLogicalArrayOperator);
|
||||
}
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (type === JsonSchemaDataType.Number) {
|
||||
return buildOptions(VariableAssignerLogicalNumberOperator);
|
||||
}
|
||||
const buildVariableAssignerLogicalOptions = useCallback(
|
||||
(record: Record<string, any>) => {
|
||||
return buildOptions(
|
||||
record,
|
||||
t,
|
||||
'flow.variableAssignerLogicalOperatorOptions',
|
||||
true,
|
||||
);
|
||||
},
|
||||
[t],
|
||||
);
|
||||
|
||||
return buildOptions(VariableAssignerLogicalOperator);
|
||||
}, []);
|
||||
const buildLogicalOptions = useCallback(
|
||||
(type: string) => {
|
||||
if (
|
||||
type?.toLowerCase().startsWith(JsonSchemaDataType.Array.toLowerCase())
|
||||
) {
|
||||
return buildVariableAssignerLogicalOptions(
|
||||
VariableAssignerLogicalArrayOperator,
|
||||
);
|
||||
}
|
||||
|
||||
if (type === JsonSchemaDataType.Number) {
|
||||
return Object.values(VariableAssignerLogicalNumberOperator).map(
|
||||
(val) => ({
|
||||
label: t(
|
||||
`flow.variableAssignerLogicalOperatorOptions.${camelCase(VariableAssignerLogicalNumberOperatorLabelMap[val as keyof typeof VariableAssignerLogicalNumberOperatorLabelMap] || val)}`,
|
||||
),
|
||||
value: val,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
return buildVariableAssignerLogicalOptions(
|
||||
VariableAssignerLogicalOperator,
|
||||
);
|
||||
},
|
||||
[buildVariableAssignerLogicalOptions, t],
|
||||
);
|
||||
|
||||
return {
|
||||
buildLogicalOptions,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { getStructuredDatatype } from '@/utils/canvas-util';
|
||||
import { get, isPlainObject } from 'lodash';
|
||||
import { ReactNode, useCallback } from 'react';
|
||||
import {
|
||||
@ -106,10 +107,10 @@ export function useFindAgentStructuredOutputTypeByValue() {
|
||||
if (isPlainObject(values) && properties) {
|
||||
for (const [key, value] of Object.entries(properties)) {
|
||||
const nextPath = path ? `${path}.${key}` : key;
|
||||
const dataType = get(value, 'type');
|
||||
const { dataType, compositeDataType } = getStructuredDatatype(value);
|
||||
|
||||
if (nextPath === target) {
|
||||
return dataType;
|
||||
return compositeDataType;
|
||||
}
|
||||
|
||||
if (
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
import { JSONSchema } from '@/components/jsonjoy-builder';
|
||||
import { getStructuredDatatype } from '@/utils/canvas-util';
|
||||
import { get, isPlainObject, toLower } from 'lodash';
|
||||
import { JsonSchemaDataType } from '../constant';
|
||||
|
||||
function predicate(types: string[], type: string) {
|
||||
return types.some((x) => toLower(x) === toLower(type));
|
||||
function predicate(types: string[], value: unknown) {
|
||||
return types.some(
|
||||
(x) =>
|
||||
toLower(x) === toLower(getStructuredDatatype(value).compositeDataType),
|
||||
);
|
||||
}
|
||||
|
||||
export function hasSpecificTypeChild(
|
||||
@ -12,7 +16,7 @@ export function hasSpecificTypeChild(
|
||||
) {
|
||||
if (Array.isArray(data)) {
|
||||
for (const value of data) {
|
||||
if (isPlainObject(value) && predicate(types, value.type)) {
|
||||
if (isPlainObject(value) && predicate(types, value)) {
|
||||
return true;
|
||||
}
|
||||
if (hasSpecificTypeChild(value, types)) {
|
||||
@ -23,7 +27,11 @@ export function hasSpecificTypeChild(
|
||||
|
||||
if (isPlainObject(data)) {
|
||||
for (const value of Object.values(data)) {
|
||||
if (isPlainObject(value) && predicate(types, value.type)) {
|
||||
if (
|
||||
isPlainObject(value) &&
|
||||
predicate(types, value) &&
|
||||
get(data, 'type') !== JsonSchemaDataType.Array
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -40,17 +40,17 @@ const AddDataSourceModal = ({
|
||||
return (
|
||||
<Modal
|
||||
title={
|
||||
<div className="flex flex-col">
|
||||
{sourceData?.icon}
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="size-6">{sourceData?.icon}</div>
|
||||
{t('setting.addDataSourceModalTital', { name: sourceData?.name })}
|
||||
</div>
|
||||
}
|
||||
open={visible || false}
|
||||
onOpenChange={(open) => !open && hideModal?.()}
|
||||
// onOk={() => handleOk()}
|
||||
okText={t('common.ok')}
|
||||
okText={t('common.confirm')}
|
||||
cancelText={t('common.cancel')}
|
||||
showfooter={false}
|
||||
footer={<div className="p-4"></div>}
|
||||
>
|
||||
<DynamicForm.Root
|
||||
fields={fields}
|
||||
@ -63,7 +63,7 @@ const AddDataSourceModal = ({
|
||||
] as FieldValues
|
||||
}
|
||||
>
|
||||
<div className="flex items-center justify-end w-full gap-2 py-4">
|
||||
<div className=" absolute bottom-0 right-0 left-0 flex items-center justify-end w-full gap-2 py-6 px-6">
|
||||
<DynamicForm.CancelButton
|
||||
handleCancel={() => {
|
||||
hideModal?.();
|
||||
@ -71,7 +71,7 @@ const AddDataSourceModal = ({
|
||||
/>
|
||||
<DynamicForm.SavingButton
|
||||
submitLoading={loading || false}
|
||||
buttonText={t('common.ok')}
|
||||
buttonText={t('common.confirm')}
|
||||
submitFunc={(values: FieldValues) => {
|
||||
handleOk(values);
|
||||
}}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
|
||||
import { Settings, Trash2 } from 'lucide-react';
|
||||
@ -32,17 +33,19 @@ export const AddedSourceCard = (props: IAddedSourceCardProps) => {
|
||||
>
|
||||
<div className="text-sm text-text-secondary ">{item.name}</div>
|
||||
<div className="text-sm text-text-secondary flex gap-2">
|
||||
<Settings
|
||||
className="cursor-pointer"
|
||||
size={14}
|
||||
<Button
|
||||
variant={'ghost'}
|
||||
className="rounded-lg px-2 py-1 bg-transparent hover:bg-bg-card"
|
||||
onClick={() => {
|
||||
toDetail(item.id);
|
||||
}}
|
||||
/>
|
||||
>
|
||||
<Settings size={14} />
|
||||
</Button>
|
||||
{/* <ConfirmDeleteDialog onOk={() => handleDelete(item)}> */}
|
||||
<Trash2
|
||||
className="cursor-pointer"
|
||||
size={14}
|
||||
<Button
|
||||
variant={'ghost'}
|
||||
className="rounded-lg px-2 py-1 bg-transparent hover:bg-state-error-5 hover:text-state-error"
|
||||
onClick={() =>
|
||||
delSourceModal({
|
||||
data: item,
|
||||
@ -51,7 +54,9 @@ export const AddedSourceCard = (props: IAddedSourceCardProps) => {
|
||||
},
|
||||
})
|
||||
}
|
||||
/>
|
||||
>
|
||||
<Trash2 className="cursor-pointer" size={14} />
|
||||
</Button>
|
||||
{/* </ConfirmDeleteDialog> */}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -71,6 +71,7 @@ export default function McpServer() {
|
||||
className="w-40"
|
||||
value={searchString}
|
||||
onChange={handleInputChange}
|
||||
placeholder={t('common.search')}
|
||||
></SearchInput>
|
||||
<Button variant={'secondary'} onClick={switchSelectionMode}>
|
||||
{isSelectionMode ? (
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog';
|
||||
import { RAGFlowTooltip } from '@/components/ui/tooltip';
|
||||
import { useDeleteMcpServer } from '@/hooks/use-mcp-request';
|
||||
import { PenLine, Trash2, Upload } from 'lucide-react';
|
||||
import { MouseEventHandler, useCallback } from 'react';
|
||||
@ -20,24 +19,24 @@ export function McpOperation({
|
||||
}, [deleteMcpServer, mcpId]);
|
||||
|
||||
return (
|
||||
<div className="hidden gap-3 group-hover:flex text-text-secondary">
|
||||
<RAGFlowTooltip tooltip={t('mcp.export')}>
|
||||
<Upload
|
||||
className="size-3 cursor-pointer"
|
||||
onClick={handleExportMcpJson([mcpId])}
|
||||
/>
|
||||
</RAGFlowTooltip>
|
||||
<RAGFlowTooltip tooltip={t('common.edit')}>
|
||||
<PenLine
|
||||
className="size-3 cursor-pointer"
|
||||
onClick={showEditModal(mcpId)}
|
||||
/>
|
||||
</RAGFlowTooltip>
|
||||
<RAGFlowTooltip tooltip={t('common.delete')}>
|
||||
<ConfirmDeleteDialog onOk={handleDelete}>
|
||||
<Trash2 className="size-3 cursor-pointer" />
|
||||
</ConfirmDeleteDialog>
|
||||
</RAGFlowTooltip>
|
||||
<div className="hidden gap-1 group-hover:flex text-text-secondary">
|
||||
{/* <RAGFlowTooltip tooltip={t('mcp.export')}> */}
|
||||
<Upload
|
||||
className="size-5 cursor-pointer p-1 rounded-sm hover:text-text-primary hover:bg-bg-card"
|
||||
onClick={handleExportMcpJson([mcpId])}
|
||||
/>
|
||||
{/* </RAGFlowTooltip>
|
||||
<RAGFlowTooltip tooltip={t('common.edit')}> */}
|
||||
<PenLine
|
||||
className="size-5 cursor-pointer p-1 rounded-sm hover:text-text-primary hover:bg-bg-card"
|
||||
onClick={showEditModal(mcpId)}
|
||||
/>
|
||||
{/* </RAGFlowTooltip>
|
||||
<RAGFlowTooltip tooltip={t('common.delete')}> */}
|
||||
<ConfirmDeleteDialog onOk={handleDelete}>
|
||||
<Trash2 className="size-5 cursor-pointer p-1 rounded-sm hover:text-state-error hover:bg-state-error-5" />
|
||||
</ConfirmDeleteDialog>
|
||||
{/* </RAGFlowTooltip> */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -143,12 +143,12 @@ const ProfilePage: FC = () => {
|
||||
{profile.userName}
|
||||
</div>
|
||||
<Button
|
||||
variant={'secondary'}
|
||||
variant={'ghost'}
|
||||
type="button"
|
||||
onClick={() => handleEditClick(EditType.editName)}
|
||||
className="text-sm text-text-secondary flex gap-1 px-1"
|
||||
className="text-sm text-text-secondary flex gap-1 px-1 border border-border-button"
|
||||
>
|
||||
<PenLine size={12} /> Edit
|
||||
<PenLine size={12} /> {t('edit')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
@ -175,12 +175,12 @@ const ProfilePage: FC = () => {
|
||||
{profile.timeZone}
|
||||
</div>
|
||||
<Button
|
||||
variant={'secondary'}
|
||||
variant={'ghost'}
|
||||
type="button"
|
||||
onClick={() => handleEditClick(EditType.editTimeZone)}
|
||||
className="text-sm text-text-secondary flex gap-1 px-1"
|
||||
className="text-sm text-text-secondary flex gap-1 px-1 border border-border-button"
|
||||
>
|
||||
<PenLine size={12} /> Edit
|
||||
<PenLine size={12} /> {t('edit')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
@ -208,12 +208,12 @@ const ProfilePage: FC = () => {
|
||||
{profile.currPasswd ? '********' : ''}
|
||||
</div>
|
||||
<Button
|
||||
variant={'secondary'}
|
||||
variant={'ghost'}
|
||||
type="button"
|
||||
onClick={() => handleEditClick(EditType.editPassword)}
|
||||
className="text-sm text-text-secondary flex gap-1 px-1"
|
||||
className="text-sm text-text-secondary flex gap-1 px-1 border border-border-button"
|
||||
>
|
||||
<PenLine size={12} /> Edit
|
||||
<PenLine size={12} /> {t('edit')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
@ -304,7 +304,7 @@ const ProfilePage: FC = () => {
|
||||
<div className="flex flex-col w-full gap-2">
|
||||
<FormLabel
|
||||
required
|
||||
className="text-sm flex justify-between text-text-secondary whitespace-nowrap"
|
||||
className="text-sm flex text-text-secondary whitespace-nowrap"
|
||||
>
|
||||
{t('currentPassword')}
|
||||
</FormLabel>
|
||||
|
||||
@ -67,21 +67,22 @@ export const ModelProviderCard: FC<IModelCardProps> = ({
|
||||
return (
|
||||
<div className={`w-full rounded-lg border border-border-button`}>
|
||||
{/* Header */}
|
||||
<div className="flex h-16 items-center justify-between p-4 cursor-pointer transition-colors">
|
||||
<div className="flex h-16 items-center justify-between p-4 cursor-pointer transition-colors text-text-secondary">
|
||||
<div className="flex items-center space-x-3">
|
||||
<LlmIcon name={item.name} />
|
||||
<div>
|
||||
<h3 className="font-medium">{item.name}</h3>
|
||||
<div className="font-medium text-xl">{item.name}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center space-x-2">
|
||||
<Button
|
||||
variant={'ghost'}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleApiKeyClick();
|
||||
}}
|
||||
className="px-3 py-1 text-sm bg-bg-input hover:bg-bg-input text-text-primary rounded-md transition-colors flex items-center space-x-1"
|
||||
className="px-3 py-1 text-sm rounded-md transition-colors flex items-center space-x-1 border border-border-default"
|
||||
>
|
||||
<SettingOutlined />
|
||||
<span>
|
||||
@ -90,23 +91,24 @@ export const ModelProviderCard: FC<IModelCardProps> = ({
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant={'ghost'}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleShowMoreClick();
|
||||
}}
|
||||
className="px-3 py-1 text-sm bg-bg-input hover:bg-bg-input text-text-primary rounded-md transition-colors flex items-center space-x-1"
|
||||
className="px-3 py-1 text-sm rounded-md transition-colors flex items-center space-x-1 border border-border-default"
|
||||
>
|
||||
<span>{visible ? t('hideModels') : t('showMoreModels')}</span>
|
||||
{!visible ? <ChevronsDown /> : <ChevronsUp />}
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant={'secondary'}
|
||||
variant={'ghost'}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleDeleteFactory();
|
||||
}}
|
||||
className="p-1 text-text-primary hover:text-state-error transition-colors"
|
||||
className=" hover:text-state-error hover:bg-state-error-5 transition-colors border border-border-default"
|
||||
>
|
||||
<Trash2 />
|
||||
</Button>
|
||||
|
||||
@ -27,7 +27,6 @@ interface IProps {
|
||||
const SystemSetting = ({ onOk, loading }: IProps) => {
|
||||
const { systemSetting: initialValues, allOptions } =
|
||||
useFetchSystemModelSettingOnMount();
|
||||
const { t: tcommon } = useTranslate('common');
|
||||
const { t } = useTranslate('setting');
|
||||
|
||||
const [formData, setFormData] = useState({
|
||||
@ -149,7 +148,7 @@ const SystemSetting = ({ onOk, loading }: IProps) => {
|
||||
<TooltipTrigger>
|
||||
<CircleQuestionMark
|
||||
size={12}
|
||||
className="ml-1 text-text-disabled text-xs"
|
||||
className="ml-1 text-text-secondary text-xs"
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
</Tooltip>
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
// src/components/AvailableModels.tsx
|
||||
import { LlmIcon } from '@/components/svg-icon';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { useSelectLlmList } from '@/hooks/llm-hooks';
|
||||
import { Plus, Search } from 'lucide-react';
|
||||
@ -77,12 +78,12 @@ export const AvailableModels: FC<{
|
||||
{/* Search Bar */}
|
||||
<div className="mb-6">
|
||||
<div className="relative">
|
||||
<input
|
||||
<Input
|
||||
type="text"
|
||||
placeholder={t('search')}
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
className="w-full px-4 py-3 pl-10 bg-bg-input border border-border-default rounded-lg focus:outline-none focus:ring-1 focus:ring-border-button transition-colors"
|
||||
className="w-full px-4 py-2 pl-10 bg-bg-input border border-border-default rounded-lg focus:outline-none focus:ring-1 focus:ring-border-button transition-colors"
|
||||
/>
|
||||
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 w-5 h-5 text-text-secondary" />
|
||||
</div>
|
||||
|
||||
@ -72,7 +72,7 @@ const TenantTable = ({ searchTerm }: { searchTerm: string }) => {
|
||||
<div className="rounded-lg bg-bg-input scrollbar-auto overflow-hidden border border-border-default">
|
||||
<Table rootClassName="rounded-lg">
|
||||
<TableHeader className="bg-bg-title">
|
||||
<TableRow>
|
||||
<TableRow className="hover:bg-bg-title">
|
||||
<TableHead className="h-12 px-4">{t('common.name')}</TableHead>
|
||||
<TableHead
|
||||
className="h-12 px-4 cursor-pointer"
|
||||
|
||||
@ -78,7 +78,7 @@ const UserTable = ({ searchUser }: { searchUser: string }) => {
|
||||
<div className="rounded-lg bg-bg-input scrollbar-auto overflow-hidden border border-border-default">
|
||||
<Table rootClassName="rounded-lg">
|
||||
<TableHeader className="bg-bg-title">
|
||||
<TableRow>
|
||||
<TableRow className="hover:bg-bg-title">
|
||||
<TableHead className="h-12 px-4">{t('common.name')}</TableHead>
|
||||
<TableHead
|
||||
className="h-12 px-4 cursor-pointer"
|
||||
|
||||
@ -6,7 +6,7 @@ import {
|
||||
import { BaseNode } from '@/interfaces/database/agent';
|
||||
|
||||
import { Edge } from '@xyflow/react';
|
||||
import { isEmpty } from 'lodash';
|
||||
import { get, isEmpty } from 'lodash';
|
||||
import { ComponentType, ReactNode } from 'react';
|
||||
|
||||
export function filterAllUpstreamNodeIds(edges: Edge[], nodeIds: string[]) {
|
||||
@ -87,3 +87,15 @@ export function buildNodeOutputOptions({
|
||||
),
|
||||
}));
|
||||
}
|
||||
|
||||
export function getStructuredDatatype(value: Record<string, any> | unknown) {
|
||||
const dataType = get(value, 'type');
|
||||
const arrayItemsType = get(value, 'items.type', JsonSchemaDataType.String);
|
||||
|
||||
const compositeDataType =
|
||||
dataType === JsonSchemaDataType.Array
|
||||
? `${dataType}<${arrayItemsType}>`
|
||||
: dataType;
|
||||
|
||||
return { dataType, compositeDataType };
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user