mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
Compare commits
66 Commits
4aa1abd8e5
...
pipeline
| Author | SHA1 | Date | |
|---|---|---|---|
| 32dbed36e3 | |||
| 7f62ab8eb3 | |||
| e87987785c | |||
| b3b0be832a | |||
| 20b577a72c | |||
| 4d6ff672eb | |||
| fb19e24f8a | |||
| 9989e06abb | |||
| c49e81882c | |||
| 63cdce660e | |||
| 8bc8126848 | |||
| 71f69cdb75 | |||
| 664bc0b961 | |||
| f4cc4dbd30 | |||
| cce361d774 | |||
| 7a63b6386e | |||
| 4996dcb0eb | |||
| 3521eb61fe | |||
| 6b9b785b5c | |||
| 4c0a89f262 | |||
| 76b1ee2a00 | |||
| 771a38434f | |||
| 886d38620e | |||
| c7efaab30e | |||
| ff49454501 | |||
| 14273b4595 | |||
| abe7132630 | |||
| c1151519a0 | |||
| a1147ce609 | |||
| d907e79893 | |||
| 1b19d302c5 | |||
| 840b2b5809 | |||
| a6039cf563 | |||
| 8be7380b79 | |||
| afb8a84f7b | |||
| 6bf0cda16f | |||
| 5715ca6b74 | |||
| 8f465525f7 | |||
| f20dca2895 | |||
| 0c557e37ad | |||
| d0bfe8b10c | |||
| 28afc7e67d | |||
| 73c33bc8d2 | |||
| 476852e8f1 | |||
| e6cf00cb33 | |||
| d039d1e73d | |||
| d050ef568d | |||
| 028c2d83e9 | |||
| b5d6a6e8f2 | |||
| 5dfdbcce3a | |||
| 4fae40f66a | |||
| a1b947ffd6 | |||
| f9c7404bee | |||
| 5c1791d7f0 | |||
| e82617f6de | |||
| a7abc57f68 | |||
| cf1f523d03 | |||
| ccb255919a | |||
| b68c84b52e | |||
| 93cf0258c3 | |||
| b79fef1ca8 | |||
| 2b50de3186 | |||
| d8ef22db68 | |||
| 592f3b1555 | |||
| 3404469e2a | |||
| 63d7382dc9 |
@ -1,9 +1,13 @@
|
||||
import argparse
|
||||
import base64
|
||||
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from typing import Dict, List, Any
|
||||
from lark import Lark, Transformer, Tree
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
GRAMMAR = r"""
|
||||
start: command
|
||||
@ -19,6 +23,8 @@ sql_command: list_services
|
||||
| show_user
|
||||
| drop_user
|
||||
| alter_user
|
||||
| create_user
|
||||
| activate_user
|
||||
| list_datasets
|
||||
| list_agents
|
||||
|
||||
@ -35,6 +41,7 @@ meta_arg: /[^\\s"']+/ | quoted_string
|
||||
LIST: "LIST"i
|
||||
SERVICES: "SERVICES"i
|
||||
SHOW: "SHOW"i
|
||||
CREATE: "CREATE"i
|
||||
SERVICE: "SERVICE"i
|
||||
SHUTDOWN: "SHUTDOWN"i
|
||||
STARTUP: "STARTUP"i
|
||||
@ -43,6 +50,7 @@ USERS: "USERS"i
|
||||
DROP: "DROP"i
|
||||
USER: "USER"i
|
||||
ALTER: "ALTER"i
|
||||
ACTIVE: "ACTIVE"i
|
||||
PASSWORD: "PASSWORD"i
|
||||
DATASETS: "DATASETS"i
|
||||
OF: "OF"i
|
||||
@ -58,12 +66,15 @@ list_users: LIST USERS ";"
|
||||
drop_user: DROP USER quoted_string ";"
|
||||
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||
show_user: SHOW USER quoted_string ";"
|
||||
create_user: CREATE USER quoted_string quoted_string ";"
|
||||
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||
|
||||
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||
list_agents: LIST AGENTS OF quoted_string ";"
|
||||
|
||||
identifier: WORD
|
||||
quoted_string: QUOTED_STRING
|
||||
status: WORD
|
||||
|
||||
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||
@ -118,6 +129,16 @@ class AdminTransformer(Transformer):
|
||||
new_password = items[4]
|
||||
return {"type": "alter_user", "username": user_name, "password": new_password}
|
||||
|
||||
def create_user(self, items):
|
||||
user_name = items[2]
|
||||
password = items[3]
|
||||
return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
|
||||
|
||||
def activate_user(self, items):
|
||||
user_name = items[3]
|
||||
activate_status = items[4]
|
||||
return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
|
||||
|
||||
def list_datasets(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_datasets", "username": user_name}
|
||||
@ -147,9 +168,12 @@ class AdminTransformer(Transformer):
|
||||
return items
|
||||
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||
return base64_encoded.decode('utf-8')
|
||||
def encrypt(input_string):
|
||||
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
|
||||
pub_key = RSA.importKey(pub)
|
||||
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
|
||||
return base64.b64encode(cipher_text).decode("utf-8")
|
||||
|
||||
|
||||
class AdminCommandParser:
|
||||
@ -220,6 +244,9 @@ class AdminCLI:
|
||||
if not data:
|
||||
print("No data to print")
|
||||
return
|
||||
if isinstance(data, dict):
|
||||
# handle single row data
|
||||
data = [data]
|
||||
|
||||
columns = list(data[0].keys())
|
||||
col_widths = {}
|
||||
@ -335,6 +362,10 @@ class AdminCLI:
|
||||
self._handle_drop_user(command_dict)
|
||||
case 'alter_user':
|
||||
self._handle_alter_user(command_dict)
|
||||
case 'create_user':
|
||||
self._handle_create_user(command_dict)
|
||||
case 'activate_user':
|
||||
self._handle_activate_user(command_dict)
|
||||
case 'list_datasets':
|
||||
self._handle_list_datasets(command_dict)
|
||||
case 'list_agents':
|
||||
@ -349,9 +380,8 @@ class AdminCLI:
|
||||
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = dict
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
res_json = response.json()
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||
@ -377,9 +407,8 @@ class AdminCLI:
|
||||
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = dict
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
res_json = response.json()
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||
@ -388,11 +417,25 @@ class AdminCLI:
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Showing user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_drop_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Drop user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||
response = requests.delete(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_alter_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
@ -400,16 +443,75 @@ class AdminCLI:
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
print(f"Alter user: {username}, password: {password}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
|
||||
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||
json={'new_password': encrypt(password)})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_create_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
role: str = command['role']
|
||||
print(f"Create user: {username}, password: {password}, role: {role}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||
response = requests.post(
|
||||
url,
|
||||
auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||
json={'username': username, 'password': encrypt(password), 'role': role}
|
||||
)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_activate_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
activate_tree: Tree = command['activate_status']
|
||||
activate_status: str = activate_tree.children[0].strip("'\"")
|
||||
if activate_status.lower() in ['on', 'off']:
|
||||
print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
|
||||
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||
json={'activate_status': activate_status})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
|
||||
else:
|
||||
print(f"Unknown activate status: {activate_status}.")
|
||||
|
||||
def _handle_list_datasets(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all datasets of user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_list_agents(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all agents of user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
|
||||
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
else:
|
||||
print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_meta_command(self, command):
|
||||
meta_command = command['command']
|
||||
@ -436,6 +538,7 @@ Commands:
|
||||
DROP USER <user>
|
||||
CREATE USER <user> <password>
|
||||
ALTER USER PASSWORD <user> <new_password>
|
||||
ALTER USER ACTIVE <user> <on/off>
|
||||
LIST DATASETS OF <user>
|
||||
LIST AGENTS OF <user>
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from flask import Flask
|
||||
from routes import admin_bp
|
||||
from api.utils.log_utils import init_root_logger
|
||||
from api.constants import SERVICE_CONF
|
||||
from api import settings
|
||||
from config import load_configurations, SERVICE_CONFIGS
|
||||
|
||||
stop_event = threading.Event()
|
||||
@ -26,7 +27,7 @@ if __name__ == '__main__':
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(admin_bp)
|
||||
|
||||
settings.init_settings()
|
||||
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
|
||||
|
||||
try:
|
||||
|
||||
@ -4,7 +4,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
from api.utils import read_config
|
||||
from api.utils.configs import read_config
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from flask import Blueprint, request
|
||||
|
||||
from auth import login_verify
|
||||
from responses import success_response, error_response
|
||||
from services import UserMgr, ServiceMgr
|
||||
from services import UserMgr, ServiceMgr, UserServiceMgr
|
||||
from exceptions import AdminException
|
||||
|
||||
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||
@ -38,21 +39,29 @@ def create_user():
|
||||
password = data['password']
|
||||
role = data.get('role', 'user')
|
||||
|
||||
user = UserMgr.create_user(username, password, role)
|
||||
return success_response(user, "User created successfully", 201)
|
||||
res = UserMgr.create_user(username, password, role)
|
||||
if res["success"]:
|
||||
user_info = res["user_info"]
|
||||
user_info.pop("password") # do not return password
|
||||
return success_response(user_info, "User created successfully")
|
||||
else:
|
||||
return error_response("create user failed")
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
return error_response(str(e))
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
||||
@login_verify
|
||||
def delete_user(username):
|
||||
try:
|
||||
UserMgr.delete_user(username)
|
||||
return success_response(None, "User and all data deleted successfully")
|
||||
res = UserMgr.delete_user(username)
|
||||
if res["success"]:
|
||||
return success_response(None, res["message"])
|
||||
else:
|
||||
return error_response(res["message"])
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
@ -69,8 +78,8 @@ def change_password(username):
|
||||
return error_response("New password is required", 400)
|
||||
|
||||
new_password = data['new_password']
|
||||
UserMgr.update_user_password(username, new_password)
|
||||
return success_response(None, "Password updated successfully")
|
||||
msg = UserMgr.update_user_password(username, new_password)
|
||||
return success_response(None, msg)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
@ -78,6 +87,21 @@ def change_password(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
||||
@login_verify
|
||||
def alter_user_activate_status(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'activate_status' not in data:
|
||||
return error_response("Activation status is required", 400)
|
||||
activate_status = data['activate_status']
|
||||
msg = UserMgr.update_user_activate_status(username, activate_status)
|
||||
return success_response(None, msg)
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_details(username):
|
||||
@ -90,6 +114,31 @@ def get_user_details(username):
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_datasets(username):
|
||||
try:
|
||||
datasets_list = UserServiceMgr.get_user_datasets(username)
|
||||
return success_response(datasets_list)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/agents', methods=['GET'])
|
||||
@login_verify
|
||||
def get_user_agents(username):
|
||||
try:
|
||||
agents_list = UserServiceMgr.get_user_agents(username)
|
||||
return success_response(agents_list)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services', methods=['GET'])
|
||||
@login_verify
|
||||
|
||||
@ -1,5 +1,13 @@
|
||||
import re
|
||||
from werkzeug.security import check_password_hash
|
||||
from api.db import ActiveEnum
|
||||
from api.db.services import UserService
|
||||
from exceptions import AdminException
|
||||
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.crypt import decrypt
|
||||
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
|
||||
from config import SERVICE_CONFIGS
|
||||
|
||||
class UserMgr:
|
||||
@ -13,19 +21,132 @@ class UserMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_user_details(username):
|
||||
raise AdminException("get_user_details: not implemented")
|
||||
# use email to query
|
||||
users = UserService.query_user_by_email(username)
|
||||
result = []
|
||||
for user in users:
|
||||
result.append({
|
||||
'email': user.email,
|
||||
'language': user.language,
|
||||
'last_login_time': user.last_login_time,
|
||||
'is_authenticated': user.is_authenticated,
|
||||
'is_active': user.is_active,
|
||||
'is_anonymous': user.is_anonymous,
|
||||
'login_channel': user.login_channel,
|
||||
'status': user.status,
|
||||
'is_superuser': user.is_superuser,
|
||||
'create_date': user.create_date,
|
||||
'update_date': user.update_date
|
||||
})
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def create_user(username, password, role="user"):
|
||||
raise AdminException("create_user: not implemented")
|
||||
def create_user(username, password, role="user") -> dict:
|
||||
# Validate the email address
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
|
||||
raise AdminException(f"Invalid email address: {username}!")
|
||||
# Check if the email address is already used
|
||||
if UserService.query(email=username):
|
||||
raise UserAlreadyExistsError(username)
|
||||
# Construct user info data
|
||||
user_info_dict = {
|
||||
"email": username,
|
||||
"nickname": "", # ask user to edit it manually in settings.
|
||||
"password": decrypt(password),
|
||||
"login_channel": "password",
|
||||
"is_superuser": role == "admin",
|
||||
}
|
||||
return create_new_user(user_info_dict)
|
||||
|
||||
@staticmethod
|
||||
def delete_user(username):
|
||||
raise AdminException("delete_user: not implemented")
|
||||
# use email to delete
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
if len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
usr = user_list[0]
|
||||
return delete_user_data(usr.id)
|
||||
|
||||
@staticmethod
|
||||
def update_user_password(username, new_password):
|
||||
raise AdminException("update_user_password: not implemented")
|
||||
def update_user_password(username, new_password) -> str:
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# check new_password different from old.
|
||||
usr = user_list[0]
|
||||
psw = decrypt(new_password)
|
||||
if check_password_hash(usr.password, psw):
|
||||
return "Same password, no need to update!"
|
||||
# update password
|
||||
UserService.update_user_password(usr.id, psw)
|
||||
return "Password updated successfully!"
|
||||
|
||||
@staticmethod
|
||||
def update_user_activate_status(username, activate_status: str):
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# check activate status different from new
|
||||
usr = user_list[0]
|
||||
# format activate_status before handle
|
||||
_activate_status = activate_status.lower()
|
||||
target_status = {
|
||||
'on': ActiveEnum.ACTIVE.value,
|
||||
'off': ActiveEnum.INACTIVE.value,
|
||||
}.get(_activate_status)
|
||||
if not target_status:
|
||||
raise AdminException(f"Invalid activate_status: {activate_status}")
|
||||
if target_status == usr.is_active:
|
||||
return f"User activate status is already {_activate_status}!"
|
||||
# update is_active
|
||||
UserService.update_user(usr.id, {"is_active": target_status})
|
||||
return f"Turn {_activate_status} user activate status successfully!"
|
||||
|
||||
class UserServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_user_datasets(username):
|
||||
# use email to find user.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# find tenants
|
||||
usr = user_list[0]
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted kb and owned kb
|
||||
return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
|
||||
|
||||
@staticmethod
|
||||
def get_user_agents(username):
|
||||
# use email to find user.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# find tenants
|
||||
usr = user_list[0]
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted agents and owned agents
|
||||
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||
return [{
|
||||
'title': r['title'],
|
||||
'permission': r['permission'],
|
||||
'canvas_type': r['canvas_type'],
|
||||
'canvas_category': r['canvas_category']
|
||||
} for r in res]
|
||||
|
||||
class ServiceMgr:
|
||||
|
||||
|
||||
@ -153,6 +153,16 @@ class Graph:
|
||||
def get_tenant_id(self):
|
||||
return self._tenant_id
|
||||
|
||||
def get_variable_value(self, exp: str) -> Any:
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
return self.globals[exp]
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
return cpn["obj"].output(var_nm)
|
||||
|
||||
|
||||
class Canvas(Graph):
|
||||
|
||||
@ -406,16 +416,6 @@ class Canvas(Graph):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_variable_value(self, exp: str) -> Any:
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
return self.globals[exp]
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
return cpn["obj"].output(var_nm)
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
|
||||
@ -137,7 +137,7 @@ class Agent(LLM, ToolBase):
|
||||
res.update(cpn.get_input_form())
|
||||
return res
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if kwargs.get("user_prompt"):
|
||||
usr_pmt = ""
|
||||
|
||||
@ -431,7 +431,7 @@ class ComponentBase(ABC):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@ -80,7 +80,7 @@ Here's description of each category:
|
||||
- Prioritize the most specific applicable category
|
||||
- Return only the category name without explanations
|
||||
- Use "Other" only when no other category fits
|
||||
|
||||
|
||||
""".format(
|
||||
"\n - ".join(list(self.category_description.keys())),
|
||||
"\n".join(descriptions)
|
||||
@ -96,7 +96,7 @@ Here's description of each category:
|
||||
class Categorize(LLM, ABC):
|
||||
component_name = "Categorize"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if not msg:
|
||||
@ -112,7 +112,7 @@ class Categorize(LLM, ABC):
|
||||
|
||||
user_prompt = """
|
||||
---- Real Data ----
|
||||
{} →
|
||||
{} →
|
||||
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
||||
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
||||
@ -134,4 +134,4 @@ class Categorize(LLM, ABC):
|
||||
self.set_output("_next", cpn_ids)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
||||
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
||||
|
||||
@ -53,7 +53,7 @@ class InvokeParam(ComponentParamBase):
|
||||
class Invoke(ComponentBase, ABC):
|
||||
component_name = "Invoke"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||
def _invoke(self, **kwargs):
|
||||
args = {}
|
||||
for para in self._param.variables:
|
||||
|
||||
@ -101,6 +101,8 @@ class LLM(ComponentBase):
|
||||
|
||||
def get_input_elements(self) -> dict[str, Any]:
|
||||
res = self.get_input_elements_from_text(self._param.sys_prompt)
|
||||
if isinstance(self._param.prompts, str):
|
||||
self._param.prompts = [{"role": "user", "content": self._param.prompts}]
|
||||
for prompt in self._param.prompts:
|
||||
d = self.get_input_elements_from_text(prompt["content"])
|
||||
res.update(d)
|
||||
@ -112,6 +114,17 @@ class LLM(ComponentBase):
|
||||
def add2system_prompt(self, txt):
|
||||
self._param.sys_prompt += txt
|
||||
|
||||
def _sys_prompt_and_msg(self, msg, args):
|
||||
if isinstance(self._param.prompts, str):
|
||||
self._param.prompts = [{"role": "user", "content": self._param.prompts}]
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
p = deepcopy(p)
|
||||
p["content"] = self.string_format(p["content"], args)
|
||||
msg.append(p)
|
||||
return msg, self.string_format(self._param.sys_prompt, args)
|
||||
|
||||
def _prepare_prompt_variables(self):
|
||||
if self._param.visual_files_var:
|
||||
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
|
||||
@ -127,7 +140,6 @@ class LLM(ComponentBase):
|
||||
|
||||
args = {}
|
||||
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
|
||||
sys_prompt = self._param.sys_prompt
|
||||
for k, o in vars.items():
|
||||
args[k] = o["value"]
|
||||
if not isinstance(args[k], str):
|
||||
@ -137,16 +149,8 @@ class LLM(ComponentBase):
|
||||
args[k] = str(args[k])
|
||||
self.set_input_value(k, args[k])
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
|
||||
for p in self._param.prompts:
|
||||
if msg and msg[-1]["role"] == p["role"]:
|
||||
continue
|
||||
msg.append(deepcopy(p))
|
||||
|
||||
sys_prompt = self.string_format(sys_prompt, args)
|
||||
msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
|
||||
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
|
||||
for m in msg:
|
||||
m["content"] = self.string_format(m["content"], args)
|
||||
if self._param.cite and self._canvas.get_reference()["chunks"]:
|
||||
sys_prompt += citation_prompt(user_defined_prompt)
|
||||
|
||||
@ -201,7 +205,7 @@ class LLM(ComponentBase):
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||
yield delta(txt)
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
@ -127,7 +127,7 @@ class Message(ComponentBase):
|
||||
]
|
||||
return any([re.search(p, content) for p in patt])
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
rand_cnt = random.choice(self._param.content)
|
||||
if self._param.stream and not self._is_jinjia2(rand_cnt):
|
||||
|
||||
@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
|
||||
"type": "line"
|
||||
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self._param.method == "split":
|
||||
self._split(kwargs.get("line"))
|
||||
|
||||
@ -61,7 +61,7 @@ class SwitchParam(ComponentParamBase):
|
||||
class Switch(ComponentBase, ABC):
|
||||
component_name = "Switch"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||
def _invoke(self, **kwargs):
|
||||
for cond in self._param.conditions:
|
||||
res = []
|
||||
|
||||
@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
|
||||
class ArXiv(ToolBase, ABC):
|
||||
component_name = "ArXiv"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -97,6 +97,6 @@ class ArXiv(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -129,7 +129,7 @@ module.exports = { main };
|
||||
class CodeExec(ToolBase, ABC):
|
||||
component_name = "CodeExec"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
lang = kwargs.get("lang", self._param.lang)
|
||||
script = kwargs.get("script", self._param.script)
|
||||
@ -157,7 +157,7 @@ class CodeExec(ToolBase, ABC):
|
||||
|
||||
try:
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run", code_req, resp.status_code)
|
||||
if resp.status_code != 200:
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
|
||||
@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
|
||||
class DuckDuckGo(ToolBase, ABC):
|
||||
component_name = "DuckDuckGo"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -115,6 +115,6 @@ class DuckDuckGo(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -98,8 +98,8 @@ class EmailParam(ToolParamBase):
|
||||
|
||||
class Email(ToolBase, ABC):
|
||||
component_name = "Email"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("to_email"):
|
||||
self.set_output("success", False)
|
||||
@ -212,4 +212,4 @@ class Email(ToolBase, ABC):
|
||||
To: {}
|
||||
Subject: {}
|
||||
Your email is on its way—sit tight!
|
||||
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
||||
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
||||
|
||||
@ -53,7 +53,7 @@ class ExeSQLParam(ToolParamBase):
|
||||
self.max_records = 1024
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql'])
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
|
||||
self.check_empty(self.database, "Database name")
|
||||
self.check_empty(self.username, "database username")
|
||||
self.check_empty(self.host, "IP Address")
|
||||
@ -78,7 +78,7 @@ class ExeSQLParam(ToolParamBase):
|
||||
class ExeSQL(ToolBase, ABC):
|
||||
component_name = "ExeSQL"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
|
||||
def convert_decimals(obj):
|
||||
@ -123,6 +123,55 @@ class ExeSQL(ToolBase, ABC):
|
||||
r'PWD=' + self._param.password
|
||||
)
|
||||
db = pyodbc.connect(conn_str)
|
||||
elif self._param.db_type == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={self._param.database};"
|
||||
f"HOSTNAME={self._param.host};"
|
||||
f"PORT={self._param.port};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={self._param.username};"
|
||||
f"PWD={self._param.password};"
|
||||
)
|
||||
try:
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
except Exception as e:
|
||||
raise Exception("Database Connection Failed! \n" + str(e))
|
||||
|
||||
sql_res = []
|
||||
formalized_content = []
|
||||
for single_sql in sqls:
|
||||
single_sql = single_sql.replace("```", "").strip()
|
||||
if not single_sql:
|
||||
continue
|
||||
single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
|
||||
|
||||
stmt = ibm_db.exec_immediate(conn, single_sql)
|
||||
rows = []
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
while row and len(rows) < self._param.max_records:
|
||||
rows.append(row)
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
|
||||
if not rows:
|
||||
sql_res.append({"content": "No record in the database!"})
|
||||
continue
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
for col in df.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(df[col]):
|
||||
df[col] = df[col].dt.strftime("%Y-%m-%d")
|
||||
|
||||
df = df.where(pd.notnull(df), None)
|
||||
|
||||
sql_res.append(convert_decimals(df.to_dict(orient="records")))
|
||||
formalized_content.append(df.to_markdown(index=False, floatfmt=".6f"))
|
||||
|
||||
ibm_db.close(conn)
|
||||
|
||||
self.set_output("json", sql_res)
|
||||
self.set_output("formalized_content", "\n\n".join(formalized_content))
|
||||
return self.output("formalized_content")
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
except Exception as e:
|
||||
@ -150,6 +199,8 @@ class ExeSQL(ToolBase, ABC):
|
||||
if pd.api.types.is_datetime64_any_dtype(single_res[col]):
|
||||
single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
|
||||
|
||||
single_res = single_res.where(pd.notnull(single_res), None)
|
||||
|
||||
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
|
||||
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
|
||||
class GitHub(ToolBase, ABC):
|
||||
component_name = "GitHub"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -88,4 +88,4 @@ class GitHub(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
|
||||
class Google(ToolBase, ABC):
|
||||
component_name = "Google"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("q"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -154,6 +154,6 @@ class Google(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
|
||||
class GoogleScholar(ToolBase, ABC):
|
||||
component_name = "GoogleScholar"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -93,4 +93,4 @@ class GoogleScholar(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -33,7 +33,7 @@ class PubMedParam(ToolParamBase):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "pubmed_search",
|
||||
"description": """
|
||||
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
||||
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
||||
In addition to MEDLINE, PubMed provides access to:
|
||||
- older references from the print version of Index Medicus, back to 1951 and earlier
|
||||
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
|
||||
@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
|
||||
class PubMed(ToolBase, ABC):
|
||||
component_name = "PubMed"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -105,4 +105,4 @@ class PubMed(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -74,7 +74,7 @@ class RetrievalParam(ToolParamBase):
|
||||
class Retrieval(ToolBase, ABC):
|
||||
component_name = "Retrieval"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
@ -164,18 +164,18 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
# Format the chunks for JSON output (similar to how other tools do it)
|
||||
json_output = kbinfos["chunks"].copy()
|
||||
|
||||
|
||||
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
|
||||
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
|
||||
|
||||
|
||||
# Set both formalized content and JSON output
|
||||
self.set_output("formalized_content", form_cnt)
|
||||
self.set_output("json", json_output)
|
||||
|
||||
|
||||
return form_cnt
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
|
||||
class SearXNG(ToolBase, ABC):
|
||||
component_name = "SearXNG"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
# Gracefully handle try-run without inputs
|
||||
query = kwargs.get("query")
|
||||
@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
# 构建搜索参数
|
||||
search_params = {
|
||||
'q': query,
|
||||
'format': 'json',
|
||||
@ -104,33 +103,29 @@ class SearXNG(ToolBase, ABC):
|
||||
'pageno': 1
|
||||
}
|
||||
|
||||
# 发送搜索请求
|
||||
response = requests.get(
|
||||
f"{searxng_url}/search",
|
||||
params=search_params,
|
||||
timeout=10
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 验证响应数据
|
||||
|
||||
if not data or not isinstance(data, dict):
|
||||
raise ValueError("Invalid response from SearXNG")
|
||||
|
||||
|
||||
results = data.get("results", [])
|
||||
if not isinstance(results, list):
|
||||
raise ValueError("Invalid results format from SearXNG")
|
||||
|
||||
# 限制结果数量
|
||||
|
||||
results = results[:self._param.top_n]
|
||||
|
||||
# 处理搜索结果
|
||||
|
||||
self._retrieve_chunks(results,
|
||||
get_title=lambda r: r.get("title", ""),
|
||||
get_url=lambda r: r.get("url", ""),
|
||||
get_content=lambda r: r.get("content", ""))
|
||||
|
||||
|
||||
self.set_output("json", results)
|
||||
return self.output("formalized_content")
|
||||
|
||||
@ -151,6 +146,6 @@ class SearXNG(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Searching with SearXNG for relevant results...
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -31,7 +31,7 @@ class TavilySearchParam(ToolParamBase):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "tavily_search",
|
||||
"description": """
|
||||
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
||||
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
||||
When searching:
|
||||
- Start with specific query which should focus on just a single aspect.
|
||||
- Number of keywords in query should be less than 5.
|
||||
@ -101,7 +101,7 @@ When searching:
|
||||
class TavilySearch(ToolBase, ABC):
|
||||
component_name = "TavilySearch"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -136,7 +136,7 @@ class TavilySearch(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
|
||||
class TavilyExtract(ToolBase, ABC):
|
||||
component_name = "TavilyExtract"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
||||
last_e = None
|
||||
@ -224,4 +224,4 @@ class TavilyExtract(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
||||
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
||||
|
||||
@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
|
||||
class WenCai(ToolBase, ABC):
|
||||
component_name = "WenCai"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("report", "")
|
||||
@ -111,4 +111,4 @@ class WenCai(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
|
||||
class Wikipedia(ToolBase, ABC):
|
||||
component_name = "Wikipedia"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
@ -99,6 +99,6 @@ class Wikipedia(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
|
||||
class YahooFinance(ToolBase, ABC):
|
||||
component_name = "YahooFinance"
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("stock_code"):
|
||||
self.set_output("report", "")
|
||||
@ -111,4 +111,4 @@ class YahooFinance(ToolBase, ABC):
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
||||
|
||||
@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services import UserService
|
||||
from api.utils import CustomJSONEncoder, commands
|
||||
from api.utils.json import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
|
||||
@ -19,15 +19,19 @@ import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from agent.component.llm import LLM
|
||||
from agent.component import LLM
|
||||
from api import settings
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.settings import RetCode
|
||||
@ -35,10 +39,12 @@ from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.db_models import APIToken, Task
|
||||
import time
|
||||
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@ -48,14 +54,6 @@ def templates():
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def canvas_list():
|
||||
return get_json_result(data=sorted([c.to_dict() for c in \
|
||||
UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.Agent)], key=lambda x: x["update_time"]*-1)
|
||||
)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
@ -77,9 +75,10 @@ def save():
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
cate = req.get("canvas_category", CanvasCategory.Agent)
|
||||
if "id" not in req:
|
||||
req["user_id"] = current_user.id
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.Agent):
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
|
||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||
req["id"] = get_uuid()
|
||||
if not UserCanvasService.save(**req):
|
||||
@ -101,7 +100,7 @@ def save():
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@ -148,6 +147,14 @@ def run():
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
|
||||
except Exception as e:
|
||||
@ -173,6 +180,44 @@ def run():
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
doc = doc[0]
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||
doc["progress_msg"] = ""
|
||||
doc["chunk_num"] = 0
|
||||
doc["token_num"] = 0
|
||||
DocumentService.clear_chunk_num_when_rerun(doc["id"])
|
||||
DocumentService.update_by_id(id, doc)
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
|
||||
dsl = req["dsl"]
|
||||
dsl["path"] = [req["component_id"]]
|
||||
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
def cancel(task_id):
|
||||
try:
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
@ -198,7 +243,7 @@ def reset():
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -348,6 +393,22 @@ def test_db_connect():
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
elif req["db_type"] == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
logging.info(conn_str)
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
else:
|
||||
return server_error_response("Unsupported database type.")
|
||||
if req["db_type"] != 'mssql':
|
||||
@ -383,22 +444,32 @@ def getversion( version_id):
|
||||
return get_json_result(data=f"Error getting history file: {e}")
|
||||
|
||||
|
||||
@manager.route('/listteam', methods=['GET']) # noqa: F821
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_canvas():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 150))
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
canvas_category = request.args.get("canvas_category")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
|
||||
if not owner_ids:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
tenants = [m["tenant_id"] for m in tenants]
|
||||
tenants.append(current_user.id)
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.Agent)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
tenants, current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, canvas_category)
|
||||
else:
|
||||
tenants = owner_ids
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
tenants, current_user.id, 0,
|
||||
0, orderby, desc, keywords, canvas_category)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
|
||||
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@ -483,3 +554,11 @@ def prompts():
|
||||
#"context_ranking": RANK_MEMORY,
|
||||
"citation_guidelines": CITATION_PROMPT_TEMPLATE
|
||||
})
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
@ -1,353 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from agent.component.llm import LLM
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.task_service import queue_dataflow
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import RetCode
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
|
||||
@manager.route("/templates", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def templates():
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def canvas_list():
|
||||
return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.delete_by_id(i)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
req["canvas_category"] = CanvasCategory.DataFlow
|
||||
if "id" not in req:
|
||||
req["user_id"] = current_user.id
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
|
||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||
req["id"] = get_uuid()
|
||||
|
||||
if not UserCanvasService.save(**req):
|
||||
return get_data_error_result(message="Fail to save canvas.")
|
||||
else:
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.update_by_id(req["id"], req)
|
||||
# save version
|
||||
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
||||
UserCanvasVersionService.delete_all_versions(req["id"])
|
||||
return get_json_result(data=req)
|
||||
|
||||
|
||||
@manager.route("/get/<canvas_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@manager.route("/run", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
flow_id = req.get("id", "")
|
||||
doc_id = req.get("doc_id", "")
|
||||
if not all([flow_id, doc_id]):
|
||||
return get_data_error_result(message="id and doc_id are required.")
|
||||
|
||||
if not DocumentService.get_by_id(doc_id):
|
||||
return get_data_error_result(message=f"Document for {doc_id} not found.")
|
||||
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not UserCanvasService.accessible(flow_id, current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(flow_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
task_id = get_uuid()
|
||||
|
||||
ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
|
||||
if not ok:
|
||||
return server_error_response(error_message)
|
||||
|
||||
return get_json_result(data={"task_id": task_id, "flow_id": flow_id})
|
||||
|
||||
|
||||
@manager.route("/reset", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
flow_id = req.get("id", "")
|
||||
if not flow_id:
|
||||
return get_data_error_result(message="id is required.")
|
||||
|
||||
if not UserCanvasService.accessible(flow_id, current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
task_id = req.get("task_id", "")
|
||||
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
|
||||
dataflow.reset()
|
||||
req["dsl"] = json.loads(str(dataflow))
|
||||
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
|
||||
return get_json_result(data=req["dsl"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None,
|
||||
}
|
||||
|
||||
if request.args.get("url"):
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter
|
||||
|
||||
try:
|
||||
url = request.args.get("url")
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
|
||||
result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
|
||||
return result
|
||||
|
||||
page = trio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
|
||||
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files["file"]
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/input_form", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def input_form():
|
||||
flow_id = request.args.get("id")
|
||||
cpn_id = request.args.get("component_id")
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(flow_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")
|
||||
|
||||
return get_json_result(data=dataflow.get_component_input_form(cpn_id))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/debug", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
canvas.reset()
|
||||
canvas.message_id = get_uuid()
|
||||
component = canvas.get_component(req["component_id"])["obj"]
|
||||
component.reset()
|
||||
|
||||
if isinstance(component, LLM):
|
||||
component.set_debug_inputs(req["params"])
|
||||
component.invoke(**{k: o["value"] for k, o in req["params"].items()})
|
||||
outputs = component.output()
|
||||
for k in outputs.keys():
|
||||
if isinstance(outputs[k], partial):
|
||||
txt = ""
|
||||
for c in outputs[k]():
|
||||
txt += c
|
||||
outputs[k] = txt
|
||||
return get_json_result(data=outputs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# api get list version dsl of canvas
|
||||
@manager.route("/getlistversion/<canvas_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
|
||||
return get_json_result(data=list)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
|
||||
# api get version dsl of canvas
|
||||
@manager.route("/getversion/<version_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def getversion(version_id):
|
||||
try:
|
||||
e, version = UserCanvasVersionService.get_by_id(version_id)
|
||||
if version:
|
||||
return get_json_result(data=version.to_dict())
|
||||
except Exception as e:
|
||||
return get_json_result(data=f"Error getting history file: {e}")
|
||||
|
||||
|
||||
@manager.route("/listteam", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_canvas():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 150))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
|
||||
)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/setting", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, flow = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
flow = flow.to_dict()
|
||||
flow["title"] = req["title"]
|
||||
for key in ("description", "permission", "avatar"):
|
||||
if value := req.get(key):
|
||||
flow[key] = value
|
||||
|
||||
num = UserCanvasService.update_by_id(req["id"], flow)
|
||||
return get_json_result(data=num)
|
||||
|
||||
|
||||
@manager.route("/trace", methods=["GET"]) # noqa: F821
|
||||
def trace():
|
||||
dataflow_id = request.args.get("dataflow_id")
|
||||
task_id = request.args.get("task_id")
|
||||
if not all([dataflow_id, task_id]):
|
||||
return get_data_error_result(message="dataflow_id and task_id are required.")
|
||||
|
||||
e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="dataflow not found.")
|
||||
|
||||
dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
|
||||
dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
|
||||
log = dataflow.fetch_logs()
|
||||
|
||||
return get_json_result(data=log)
|
||||
@ -32,7 +32,7 @@ from api.db.services.document_service import DocumentService, doc_upload_and_par
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
@ -182,6 +182,7 @@ def create():
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": FileType.VIRTUAL,
|
||||
@ -479,8 +480,11 @@ def run():
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -546,31 +550,22 @@ def get(doc_id):
|
||||
|
||||
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "parser_id")
|
||||
@validate_request("doc_id")
|
||||
def change_parser():
|
||||
req = request.json
|
||||
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if doc.parser_id.lower() == req["parser_id"].lower():
|
||||
if "parser_config" in req:
|
||||
if req["parser_config"] == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
|
||||
return get_data_error_result(message="Not supported yet!")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
def reset_doc():
|
||||
nonlocal doc
|
||||
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if "parser_config" in req:
|
||||
DocumentService.update_parser_config(doc.id, req["parser_config"])
|
||||
if doc.token_num > 0:
|
||||
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
|
||||
if not e:
|
||||
@ -581,6 +576,26 @@ def change_parser():
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
try:
|
||||
if "pipeline_id" in req:
|
||||
if doc.pipeline_id == req["pipeline_id"]:
|
||||
return get_json_result(data=True)
|
||||
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
|
||||
reset_doc()
|
||||
return get_json_result(data=True)
|
||||
|
||||
if doc.parser_id.lower() == req["parser_id"].lower():
|
||||
if "parser_config" in req:
|
||||
if req["parser_config"] == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
|
||||
return get_data_error_result(message="Not supported yet!")
|
||||
if "parser_config" in req:
|
||||
DocumentService.update_parser_config(doc.id, req["parser_config"])
|
||||
reset_doc()
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -246,6 +246,8 @@ def rm():
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
@ -292,6 +294,8 @@ def rename():
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="File not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.type != FileType.FOLDER.value \
|
||||
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
@ -328,6 +332,8 @@ def get(file_id):
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
@ -367,6 +373,8 @@ def move():
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_data_error_result(message="Parent Folder not found!")
|
||||
|
||||
@ -14,18 +14,21 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils import get_uuid
|
||||
from api.db import StatusEnum, FileSource
|
||||
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
from api.utils.api_utils import get_json_result
|
||||
@ -35,7 +38,6 @@ from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
@ -61,10 +63,39 @@ def create():
|
||||
req["name"] = dataset_name
|
||||
req["tenant_id"] = current_user.id
|
||||
req["created_by"] = current_user.id
|
||||
if not req.get("parser_id"):
|
||||
req["parser_id"] = "naive"
|
||||
e, t = TenantService.get_by_id(current_user.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
req["embd_id"] = t.embd_id
|
||||
req["parser_config"] = {
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 512,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False,
|
||||
"topn_tags": 3,
|
||||
"raptor": {
|
||||
"use_raptor": True,
|
||||
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
|
||||
"max_token": 256,
|
||||
"threshold": 0.1,
|
||||
"max_cluster": 64,
|
||||
"random_seed": 0
|
||||
},
|
||||
"graphrag": {
|
||||
"use_graphrag": True,
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
"geo",
|
||||
"event",
|
||||
"category"
|
||||
],
|
||||
"method": "light"
|
||||
}
|
||||
}
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": req["id"]})
|
||||
@ -395,3 +426,352 @@ def get_basic_info():
|
||||
basic_info = DocumentService.knowledgebase_basic_info(kb_id)
|
||||
|
||||
return get_json_result(data=basic_info)
|
||||
|
||||
|
||||
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
keywords = request.args.get("keywords", "")
|
||||
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_date_from = request.args.get("create_date_from", "")
|
||||
create_date_to = request.args.get("create_date_to", "")
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_date_from = request.args.get("create_date_from", "")
|
||||
create_date_to = request.args.get("create_date_to", "")
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
req = request.get_json()
|
||||
log_ids = req.get("log_ids", [])
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/pipeline_log_detail", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def pipeline_log_detail():
|
||||
log_id = request.args.get("log_id")
|
||||
if not log_id:
|
||||
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
ok, log = PipelineOperationLogService.get_by_id(log_id)
|
||||
if not ok:
|
||||
return get_data_error_result(message="Invalid pipeline log ID")
|
||||
|
||||
return get_json_result(data=log.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_graphrag():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_raptor():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"raptor_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_raptor", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_raptor():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_mindmap():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
|
||||
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"mindmap_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_mindmap():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Mindmap Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/unbind_task", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_kb_task():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_json_result(data=True)
|
||||
|
||||
pipeline_task_type = request.args.get("pipeline_task_type", "")
|
||||
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
return get_error_data_result(message="Invalid task type")
|
||||
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
kb_task_id = "graphrag_task_id"
|
||||
kb_task_finish_at = "graphrag_task_finish_at"
|
||||
case PipelineTaskType.RAPTOR:
|
||||
kb_task_id = "raptor_task_id"
|
||||
kb_task_finish_at = "raptor_task_finish_at"
|
||||
case PipelineTaskType.MINDMAP:
|
||||
kb_task_id = "mindmap_task_id"
|
||||
kb_task_finish_at = "mindmap_task_finish_at"
|
||||
case _:
|
||||
return get_error_data_result(message="Internal Error: Invalid task type")
|
||||
|
||||
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id: "", kb_task_finish_at: None})
|
||||
if not ok:
|
||||
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -39,6 +39,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
|
||||
|
||||
@manager.route("/version", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def version():
|
||||
|
||||
@ -98,7 +98,14 @@ def login():
|
||||
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
if user:
|
||||
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.FORBIDDEN,
|
||||
message="This account has been disabled, please contact the administrator!",
|
||||
)
|
||||
elif user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
@ -227,6 +234,9 @@ def oauth_callback(channel):
|
||||
# User exists, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect(f"/?auth={user.get_id()}")
|
||||
@ -317,6 +327,8 @@ def github_callback():
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
@ -418,6 +430,8 @@ def feishu_callback():
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.save()
|
||||
|
||||
2
api/common/README.md
Normal file
2
api/common/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
The python files in this directory are shared between service. They contain common utilities, models, and functions that can be used across various
|
||||
services to ensure consistency and reduce code duplication.
|
||||
21
api/common/base64.py
Normal file
21
api/common/base64.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||
return base64_encoded.decode('utf-8')
|
||||
@ -23,6 +23,11 @@ class StatusEnum(Enum):
|
||||
INVALID = "0"
|
||||
|
||||
|
||||
class ActiveEnum(Enum):
|
||||
ACTIVE = "1"
|
||||
INACTIVE = "0"
|
||||
|
||||
|
||||
class UserTenantRole(StrEnum):
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
@ -122,4 +127,15 @@ class MCPServerType(StrEnum):
|
||||
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
PARSE = "Parse"
|
||||
DOWNLOAD = "Download"
|
||||
RAPTOR = "RAPTOR"
|
||||
GRAPH_RAG = "GraphRAG"
|
||||
MINDMAP = "Mindmap"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
|
||||
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
|
||||
@ -26,12 +26,14 @@ from functools import wraps
|
||||
|
||||
from flask_login import UserMixin
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
|
||||
from api import settings, utils
|
||||
from api.db import ParserType, SerializedType
|
||||
from api.utils.json import json_dumps, json_loads
|
||||
from api.utils.configs import deserialize_b64, serialize_b64
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
|
||||
@ -70,12 +72,12 @@ class JSONField(LongTextField):
|
||||
def db_value(self, value):
|
||||
if value is None:
|
||||
value = self.default_value
|
||||
return utils.json_dumps(value)
|
||||
return json_dumps(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if not value:
|
||||
return self.default_value
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
|
||||
|
||||
class ListField(JSONField):
|
||||
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
|
||||
|
||||
def db_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.serialize_b64(value, to_str=True)
|
||||
return serialize_b64(value, to_str=True)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return None
|
||||
return utils.json_dumps(value, with_type=True)
|
||||
return json_dumps(value, with_type=True)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
def python_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.deserialize_b64(value)
|
||||
return deserialize_b64(value)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return {}
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=True):
|
||||
from peewee import OperationalError
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().execute_sql(sql, params, commit)
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2**attempt))
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
logging.error(f"DB execution failure: {e}")
|
||||
raise
|
||||
return None
|
||||
|
||||
def _handle_connection_loss(self):
|
||||
self.close_all()
|
||||
self.connect()
|
||||
# self.close_all()
|
||||
# self.connect()
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.connect()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to reconnect: {e}")
|
||||
time.sleep(0.1)
|
||||
self.connect()
|
||||
|
||||
def begin(self):
|
||||
from peewee import OperationalError
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().begin()
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2**attempt))
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
raise
|
||||
|
||||
@ -299,7 +328,16 @@ class BaseDataBase:
|
||||
def __init__(self):
|
||||
database_config = settings.DATABASE.copy()
|
||||
db_name = database_config.pop("name")
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
|
||||
pool_config = {
|
||||
'max_retries': 5,
|
||||
'retry_delay': 1,
|
||||
}
|
||||
database_config.update(pool_config)
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||
db_name, **database_config
|
||||
)
|
||||
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
logging.info("init database on cluster mode successfully")
|
||||
|
||||
|
||||
@ -646,8 +684,17 @@ class Knowledgebase(DataBaseModel):
|
||||
vector_similarity_weight = FloatField(default=0.3, index=True)
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
graphrag_task_finish_at = DateTimeField(null=True)
|
||||
raptor_task_id = CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)
|
||||
raptor_task_finish_at = DateTimeField(null=True)
|
||||
mindmap_task_id = CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)
|
||||
mindmap_task_finish_at = DateTimeField(null=True)
|
||||
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
@ -662,6 +709,7 @@ class Document(DataBaseModel):
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
@ -904,6 +952,32 @@ class Search(DataBaseModel):
|
||||
db_table = "search"
|
||||
|
||||
|
||||
class PipelineOperationLog(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
document_id = CharField(max_length=32, index=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
kb_id = CharField(max_length=32, null=False, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
pipeline_title = CharField(max_length=32, null=True, help_text="Pipeline title", index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="Parser ID", index=True)
|
||||
document_name = CharField(max_length=255, null=False, help_text="File name")
|
||||
document_suffix = CharField(max_length=255, null=False, help_text="File suffix")
|
||||
document_type = CharField(max_length=255, null=False, help_text="Document type")
|
||||
source_from = CharField(max_length=255, null=False, help_text="Source")
|
||||
progress = FloatField(default=0, index=True)
|
||||
progress_msg = TextField(null=True, help_text="process message", default="")
|
||||
process_begin_at = DateTimeField(null=True, index=True)
|
||||
process_duration = FloatField(default=0)
|
||||
dsl = JSONField(null=True, default=dict)
|
||||
task_type = CharField(max_length=32, null=False, default="")
|
||||
operation_status = CharField(max_length=32, null=False, help_text="Operation status")
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "pipeline_operation_log"
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1020,7 +1094,6 @@ def migrate_db():
|
||||
migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
|
||||
except Exception:
|
||||
@ -1037,4 +1110,36 @@ def migrate_db():
|
||||
migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@ -32,11 +31,7 @@ from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_l
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||
return base64_encoded.decode('utf-8')
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
|
||||
def init_superuser():
|
||||
|
||||
327
api/db/joint_services/user_account_service.py
Normal file
327
api/db/joint_services/user_account_service.py
Normal file
@ -0,0 +1,327 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from api import settings
|
||||
from api.utils.api_utils import group_by
|
||||
from api.db import FileType, UserTenantRole, ActiveEnum
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
def create_new_user(user_info: dict) -> dict:
|
||||
"""
|
||||
Add a new user, and create tenant, tenant llm, file folder for new user.
|
||||
:param user_info: {
|
||||
"email": <example@example.com>,
|
||||
"nickname": <str, "name">,
|
||||
"password": <decrypted password>,
|
||||
"login_channel": <enum, "password">,
|
||||
"is_superuser": <bool, role == "admin">,
|
||||
}
|
||||
:return: {
|
||||
"success": <bool>,
|
||||
"user_info": <dict>, # if true, return user_info
|
||||
}
|
||||
"""
|
||||
# generate user_id and access_token for user
|
||||
user_id = uuid.uuid1().hex
|
||||
user_info['id'] = user_id
|
||||
user_info['access_token'] = uuid.uuid1().hex
|
||||
# construct tenant info
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
"name": user_info["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
"rerank_id": settings.RERANK_MDL,
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
"user_id": user_id,
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER,
|
||||
}
|
||||
# construct file folder info
|
||||
file_id = uuid.uuid1().hex
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": user_id,
|
||||
"created_by": user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
try:
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user_info):
|
||||
return {"success": False}
|
||||
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
FileService.insert(file)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"user_info": user_info,
|
||||
}
|
||||
|
||||
except Exception as create_error:
|
||||
logging.exception(create_error)
|
||||
# rollback
|
||||
try:
|
||||
TenantService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
u = UserTenantService.query(tenant_id=user_id)
|
||||
if u:
|
||||
UserTenantService.delete_by_id(u[0].id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
TenantLLMService.delete_by_tenant_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
FileService.delete_by_id(file["id"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# delete user row finally
|
||||
try:
|
||||
UserService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# reraise
|
||||
raise create_error
|
||||
|
||||
|
||||
def delete_user_data(user_id: str) -> dict:
|
||||
# use user_id to delete
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
if not usr:
|
||||
return {"success": False, "message": f"{user_id} can't be found."}
|
||||
# check is inactive and not admin
|
||||
if usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return {"success": False, "message": f"{user_id} is active and can't be deleted."}
|
||||
if usr.is_superuser:
|
||||
return {"success": False, "message": "Can't delete the super user."}
|
||||
# tenant info
|
||||
tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
|
||||
owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]
|
||||
|
||||
done_msg = ''
|
||||
try:
|
||||
# step1. delete owned tenant info
|
||||
if owned_tenant:
|
||||
done_msg += "Start to delete owned tenant.\n"
|
||||
tenant_id = owned_tenant[0]["tenant_id"]
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
|
||||
# step1.1 delete knowledgebase related file and info
|
||||
if kb_ids:
|
||||
# step1.1.1 delete files in storage, remove bucket
|
||||
for kb_id in kb_ids:
|
||||
if STORAGE_IMPL.bucket_exists(kb_id):
|
||||
STORAGE_IMPL.remove_bucket(kb_id)
|
||||
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
|
||||
# step1.1.2 delete file and document info in db
|
||||
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
|
||||
if doc_ids:
|
||||
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {doc_delete_res} document records.\n"
|
||||
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {task_delete_res} task records.\n"
|
||||
file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
|
||||
if file_ids:
|
||||
file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
if doc_ids or file_ids:
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[i["id"] for i in doc_ids],
|
||||
[f["id"] for f in file_ids]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step1.1.3 delete chunk in es
|
||||
r = settings.docStoreConn.delete({"kb_id": kb_ids},
|
||||
search.index_name(tenant_id), kb_ids)
|
||||
done_msg += f"- Deleted {r} chunk records.\n"
|
||||
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
|
||||
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
|
||||
# step1.1.4 delete agents
|
||||
agent_delete_res = delete_user_agents(usr.id)
|
||||
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
|
||||
# step1.1.5 delete dialogs
|
||||
dialog_delete_res = delete_user_dialogs(usr.id)
|
||||
done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
|
||||
# step1.1.6 delete mcp server
|
||||
mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
|
||||
# step1.1.7 delete search
|
||||
search_delete_res = SearchService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {search_delete_res} search records.\n"
|
||||
# step1.2 delete tenant_llm and tenant_langfuse
|
||||
llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
|
||||
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
|
||||
# step1.3 delete own tenant
|
||||
tenant_delete_res = TenantService.delete_by_id(tenant_id)
|
||||
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
|
||||
# step2 delete user-tenant relation
|
||||
if tenants:
|
||||
# step2.1 delete docs and files in joined team
|
||||
joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
|
||||
if joined_tenants:
|
||||
done_msg += "Start to delete data in joined tenants.\n"
|
||||
created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
|
||||
if created_documents:
|
||||
# step2.1.1 delete files
|
||||
doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
|
||||
created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
|
||||
if created_files:
|
||||
# step2.1.1.1 delete file in storage
|
||||
for f in created_files:
|
||||
STORAGE_IMPL.rm(f.parent_id, f.location)
|
||||
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
|
||||
# step2.1.1.2 delete file record
|
||||
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
# step2.1.2 delete document-file relation record
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[d['id'] for d in created_documents],
|
||||
[f.id for f in created_files]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step2.1.3 delete chunks
|
||||
doc_groups = group_by(created_documents, "tenant_id")
|
||||
kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
|
||||
# chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
|
||||
chunk_delete_res = 0
|
||||
kb_doc_info = {}
|
||||
for _tenant_id, kb_doc in kb_grouped_doc.items():
|
||||
for _kb_id, docs in kb_doc.items():
|
||||
chunk_delete_res += settings.docStoreConn.delete(
|
||||
{"doc_id": [d["id"] for d in docs]},
|
||||
search.index_name(_tenant_id), _kb_id
|
||||
)
|
||||
# record doc info
|
||||
if _kb_id in kb_doc_info.keys():
|
||||
kb_doc_info[_kb_id]['doc_num'] += 1
|
||||
kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
|
||||
kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
|
||||
else:
|
||||
kb_doc_info[_kb_id] = {
|
||||
'doc_num': 1,
|
||||
'token_num': sum([d["token_num"] for d in docs]),
|
||||
'chunk_num': sum([d["chunk_num"] for d in docs])
|
||||
}
|
||||
done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
|
||||
# step2.1.4 delete tasks
|
||||
task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {task_delete_res} tasks.\n"
|
||||
# step2.1.5 delete document record
|
||||
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {doc_delete_res} documents.\n"
|
||||
# step2.1.6 update knowledge base doc&chunk&token cnt
|
||||
for kb_id, doc_num in kb_doc_info.items():
|
||||
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
|
||||
|
||||
# step2.2 delete relation
|
||||
user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
|
||||
done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
|
||||
# step3 finally delete user
|
||||
user_delete_res = UserService.delete_by_id(usr.id)
|
||||
done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"
|
||||
|
||||
return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
|
||||
|
||||
|
||||
def delete_user_agents(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"agents_deleted_count": 1,
|
||||
"version_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
agents_deleted_count, agents_version_deleted_count = 0, 0
|
||||
user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
|
||||
if user_agents:
|
||||
agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
|
||||
agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
|
||||
agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
|
||||
return {
|
||||
"agents_deleted_count": agents_deleted_count,
|
||||
"version_deleted_count": agents_version_deleted_count
|
||||
}
|
||||
|
||||
|
||||
def delete_user_dialogs(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"dialogs_deleted_count": 1,
|
||||
"conversations_deleted_count": 1,
|
||||
"api_token_deleted_count": 2,
|
||||
"api4conversation_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
|
||||
user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
|
||||
if user_dialogs:
|
||||
# delete conversation
|
||||
conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
|
||||
# delete api token
|
||||
api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
|
||||
# delete api for conversation
|
||||
api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
# delete dialog at last
|
||||
dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
|
||||
return {
|
||||
"dialogs_deleted_count": dialog_deleted_count,
|
||||
"conversations_deleted_count": conversations_deleted_count,
|
||||
"api_token_deleted_count": api_token_deleted_count,
|
||||
"api4conversation_deleted_count": api4conversation_deleted_count
|
||||
}
|
||||
@ -35,6 +35,11 @@ class APITokenService(CommonService):
|
||||
cls.model.token == token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
|
||||
class API4ConversationService(CommonService):
|
||||
model = API4Conversation
|
||||
@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
|
||||
cls.model.create_date <= to_date,
|
||||
cls.model.source == source
|
||||
).group_by(cls.model.create_date.truncate("day")).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_dialog_ids(cls, dialog_ids):
|
||||
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
|
||||
|
||||
@ -63,7 +63,38 @@ class UserCanvasService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, pid):
|
||||
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
# find team agents and owned agents
|
||||
agents = cls.model.select(*fields).where(
|
||||
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time, asc
|
||||
agents.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
ag_batch = agents.offset(offset).limit(limit)
|
||||
_temp = list(ag_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_canvas_id(cls, pid):
|
||||
try:
|
||||
|
||||
fields = [
|
||||
@ -95,7 +126,7 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page,
|
||||
orderby, desc, keywords, canvas_category=CanvasCategory.Agent,
|
||||
orderby, desc, keywords, canvas_category=None
|
||||
):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
@ -104,6 +135,7 @@ class UserCanvasService(CommonService):
|
||||
cls.model.dsl,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.user_id.alias("tenant_id"),
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
cls.model.update_time,
|
||||
@ -111,31 +143,33 @@ class UserCanvasService(CommonService):
|
||||
]
|
||||
if keywords:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id)),
|
||||
(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
cls.model.user_id.in_(joined_tenant_ids),
|
||||
fn.LOWER(cls.model.title).contains(keywords.lower())
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
|
||||
#(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id))
|
||||
cls.model.user_id.in_(joined_tenant_ids)
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
)
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
if canvas_category:
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
if desc:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
count = agents.count()
|
||||
agents = agents.paginate(page_number, items_per_page)
|
||||
if page_number and items_per_page:
|
||||
agents = agents.paginate(page_number, items_per_page)
|
||||
return list(agents.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, canvas_id, tenant_id):
|
||||
from api.db.services.user_service import UserTenantService
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return False
|
||||
|
||||
|
||||
@ -14,12 +14,24 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
import peewee
|
||||
from peewee import InterfaceError, OperationalError
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
def retry_db_operation(func):
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||
reraise=True,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class CommonService:
|
||||
"""Base service class that provides common database operations.
|
||||
@ -202,6 +214,7 @@ class CommonService:
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@retry_db_operation
|
||||
def update_by_id(cls, pid, data):
|
||||
# Update a single record by ID
|
||||
# Args:
|
||||
|
||||
@ -48,6 +48,21 @@ class ConversationService(CommonService):
|
||||
|
||||
return list(sessions.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
|
||||
sessions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
s_batch = sessions.offset(offset).limit(limit)
|
||||
_temp = list(s_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def structure_answer(conv, ans, message_id, session_id):
|
||||
reference = ans["reference"]
|
||||
|
||||
@ -159,6 +159,22 @@ class DialogService(CommonService):
|
||||
|
||||
return list(dialogs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_dialogs_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
dialogs.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
d_batch = dialogs.offset(offset).limit(limit)
|
||||
_temp = list(d_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
|
||||
@ -24,12 +24,13 @@ from io import BytesIO
|
||||
|
||||
import trio
|
||||
import xxhash
|
||||
from peewee import fn, Case
|
||||
from peewee import fn, Case, JOIN
|
||||
|
||||
from api import settings
|
||||
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
|
||||
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File
|
||||
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole, CanvasCategory
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
|
||||
User
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -51,6 +52,7 @@ class DocumentService(CommonService):
|
||||
cls.model.thumbnail,
|
||||
cls.model.kb_id,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.parser_config,
|
||||
cls.model.source_type,
|
||||
cls.model.type,
|
||||
@ -79,7 +81,10 @@ class DocumentService(CommonService):
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name):
|
||||
fields = cls.get_cls_model_fields()
|
||||
docs = cls.model.select(*fields).join(File2Document, on = (File2Document.document_id == cls.model.id)).join(File, on = (File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
||||
.join(File, on = (File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
if id:
|
||||
docs = docs.where(
|
||||
cls.model.id == id)
|
||||
@ -117,12 +122,22 @@ class DocumentService(CommonService):
|
||||
orderby, desc, keywords, run_status, types, suffix):
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
@ -228,6 +243,46 @@ class DocumentService(CommonService):
|
||||
|
||||
return int(query.scalar()) or 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
|
||||
fields = [cls.model.id]
|
||||
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_docs_by_creator_id(cls, creator_id):
|
||||
fields = [
|
||||
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
|
||||
]
|
||||
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.created_by == creator_id
|
||||
)
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, doc):
|
||||
@ -330,8 +385,7 @@ class DocumentService(CommonService):
|
||||
process_duration=cls.model.process_duration + duration).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
logging.warning("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num +
|
||||
token_num,
|
||||
@ -597,6 +651,22 @@ class DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def update_progress(cls):
|
||||
docs = cls.get_unfinished_docs()
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress_immediately(cls, docs:list[dict]):
|
||||
if not docs:
|
||||
return
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def _sync_progress(cls, docs:list[dict]):
|
||||
for d in docs:
|
||||
try:
|
||||
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
|
||||
@ -606,8 +676,6 @@ class DocumentService(CommonService):
|
||||
prg = 0
|
||||
finished = True
|
||||
bad = 0
|
||||
has_raptor = False
|
||||
has_graphrag = False
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
priority = 0
|
||||
@ -619,24 +687,14 @@ class DocumentService(CommonService):
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
if t.progress_msg.strip():
|
||||
msg.append(t.progress_msg)
|
||||
if t.task_type == "raptor":
|
||||
has_raptor = True
|
||||
elif t.task_type == "graphrag":
|
||||
has_graphrag = True
|
||||
priority = max(priority, t.priority)
|
||||
prg /= len(tsks)
|
||||
if finished and bad:
|
||||
prg = -1
|
||||
status = TaskStatus.FAIL.value
|
||||
elif finished:
|
||||
if (d["parser_config"].get("raptor") or {}).get("use_raptor") and not has_raptor:
|
||||
queue_raptor_o_graphrag_tasks(d, "raptor", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif (d["parser_config"].get("graphrag") or {}).get("use_graphrag") and not has_graphrag:
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
else:
|
||||
status = TaskStatus.DONE.value
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
@ -648,7 +706,7 @@ class DocumentService(CommonService):
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
info["progress_msg"] = msg
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
|
||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
else:
|
||||
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
@ -729,7 +787,11 @@ class DocumentService(CommonService):
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
Optionally, specify a list of doc_ids to determine which documents participate in the task.
|
||||
"""
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
@ -739,11 +801,12 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
@ -752,7 +815,12 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
|
||||
hasher.update(ty.encode("utf-8"))
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
|
||||
if ty in ["graphrag", "raptor", "mindmap"]:
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
|
||||
def get_queue_length(priority):
|
||||
|
||||
@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
|
||||
objs = cls.model.select().where(cls.model.document_id == document_id)
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_document_ids(cls, document_ids):
|
||||
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
|
||||
return list(objs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, obj):
|
||||
@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
|
||||
def delete_by_file_id(cls, file_id):
|
||||
return cls.model.delete().where(cls.model.file_id == file_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
|
||||
if not document_ids:
|
||||
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
|
||||
elif not file_ids:
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_id(cls, doc_id):
|
||||
|
||||
@ -161,6 +161,23 @@ class FileService(CommonService):
|
||||
result_ids.append(folder_id)
|
||||
return result_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_file_ids_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
files.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
file_batch = files.offset(offset).limit(limit)
|
||||
_temp = list(file_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_folder(cls, file, parent_id, name, count):
|
||||
@ -440,6 +457,7 @@ class FileService(CommonService):
|
||||
"id": doc_id,
|
||||
"kb_id": kb.id,
|
||||
"parser_id": self.get_parser(filetype, filename, kb.parser_id),
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": user_id,
|
||||
"type": filetype,
|
||||
@ -495,7 +513,7 @@ class FileService(CommonService):
|
||||
return ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
return ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
if re.search(r"\.(msg|eml)$", filename):
|
||||
return ParserType.EMAIL.value
|
||||
return default
|
||||
|
||||
|
||||
@ -15,10 +15,10 @@
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import fn
|
||||
from peewee import fn, JOIN
|
||||
|
||||
from api.db import StatusEnum, TenantPermission
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
|
||||
from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
return list(kbs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted kb, be cautious.
|
||||
fields = [
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.status,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date
|
||||
]
|
||||
# find team kb and owned kb
|
||||
kbs = cls.model.select(*fields).where(
|
||||
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time asc
|
||||
kbs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later.
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
kb_batch = kbs.offset(offset).limit(limit)
|
||||
_temp = list(kb_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_ids(cls, tenant_id):
|
||||
@ -225,20 +260,29 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
UserCanvas.title.alias("pipeline_name"),
|
||||
UserCanvas.avatar.alias("pipeline_avatar"),
|
||||
cls.model.parser_config,
|
||||
cls.model.pagerank,
|
||||
cls.model.graphrag_task_id,
|
||||
cls.model.graphrag_task_finish_at,
|
||||
cls.model.raptor_task_id,
|
||||
cls.model.raptor_task_finish_at,
|
||||
cls.model.mindmap_task_id,
|
||||
cls.model.mindmap_task_finish_at,
|
||||
cls.model.create_time,
|
||||
cls.model.update_time
|
||||
]
|
||||
kbs = cls.model.select(*fields).join(Tenant, on=(
|
||||
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
|
||||
kbs = cls.model.select(*fields)\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
).dicts()
|
||||
if not kbs:
|
||||
return
|
||||
d = kbs[0].to_dict()
|
||||
return d
|
||||
return kbs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -436,3 +480,17 @@ class KnowledgebaseService(CommonService):
|
||||
else:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
|
||||
kb_row = cls.model.get_by_id(kb_id)
|
||||
if not kb_row:
|
||||
raise RuntimeError(f"kb_id {kb_id} does not exist")
|
||||
update_dict = {
|
||||
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
|
||||
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
|
||||
'token_num': kb_row.token_num - doc_num_info['token_num'],
|
||||
'update_time': current_timestamp(),
|
||||
'update_date': datetime_format(datetime.now())
|
||||
}
|
||||
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||
|
||||
@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_ty_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@classmethod
|
||||
def update_by_tenant(cls, tenant_id, langfuse_keys):
|
||||
langfuse_keys["update_time"] = current_timestamp()
|
||||
|
||||
@ -84,3 +84,8 @@ class MCPServerService(CommonService):
|
||||
return bool(mcp_server), mcp_server
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
263
api/db/services/pipeline_operation_log_service.py
Normal file
263
api/db/services/pipeline_operation_log_service.py
Normal file
@ -0,0 +1,263 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType
|
||||
from api.db.db_models import DB, Document, PipelineOperationLog
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
|
||||
class PipelineOperationLogService(CommonService):
|
||||
model = PipelineOperationLog
|
||||
|
||||
@classmethod
|
||||
def get_file_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.document_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.pipeline_title,
|
||||
cls.model.parser_id,
|
||||
cls.model.document_name,
|
||||
cls.model.document_suffix,
|
||||
cls.model.document_type,
|
||||
cls.model.source_from,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.dsl,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_dataset_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def save(cls, **kwargs):
|
||||
"""
|
||||
wrap this function in a transaction
|
||||
"""
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: str = "{}"):
|
||||
referred_document_id = document_id
|
||||
|
||||
if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
|
||||
referred_document_id = fake_document_ids[0]
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
DocumentService.update_progress_immediately([document.to_dict()])
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
if document.progress not in [1, -1]:
|
||||
return
|
||||
operation_status = document.run
|
||||
|
||||
if pipeline_id:
|
||||
ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
||||
tenant_id = user_pipeline.user_id
|
||||
title = user_pipeline.title
|
||||
avatar = user_pipeline.avatar
|
||||
else:
|
||||
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
|
||||
|
||||
tenant_id = kb_info.tenant_id
|
||||
title = document.parser_id
|
||||
avatar = document.thumbnail
|
||||
|
||||
if task_type not in VALID_PIPELINE_TASK_TYPES:
|
||||
raise ValueError(f"Invalid task type: {task_type}")
|
||||
|
||||
if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
finish_at = document.process_begin_at + timedelta(seconds=document.process_duration)
|
||||
if task_type == PipelineTaskType.GRAPH_RAG:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"graphrag_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.RAPTOR:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"raptor_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.MINDMAP:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"mindmap_task_finish_at": finish_at},
|
||||
)
|
||||
|
||||
log = dict(
|
||||
id=get_uuid(),
|
||||
document_id=document_id, # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
|
||||
tenant_id=tenant_id,
|
||||
kb_id=document.kb_id,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_title=title,
|
||||
parser_id=document.parser_id,
|
||||
document_name=document.name,
|
||||
document_suffix=document.suffix,
|
||||
document_type=document.type,
|
||||
source_from="", # TODO: add in the future
|
||||
progress=document.progress,
|
||||
progress_msg=document.progress_msg,
|
||||
process_begin_at=document.process_begin_at,
|
||||
process_duration=document.process_duration,
|
||||
dsl=json.loads(dsl),
|
||||
task_type=task_type,
|
||||
operation_status=operation_status,
|
||||
avatar=avatar,
|
||||
)
|
||||
log["create_time"] = current_timestamp()
|
||||
log["create_date"] = datetime_format(datetime.now())
|
||||
log["update_time"] = current_timestamp()
|
||||
log["update_date"] = datetime_format(datetime.now())
|
||||
|
||||
with DB.atomic():
|
||||
obj = cls.save(**log)
|
||||
|
||||
limit = int(os.getenv("PIPELINE_OPERATION_LOG_LIMIT", 1000))
|
||||
total = cls.model.select().where(cls.model.kb_id == document.kb_id).count()
|
||||
|
||||
if total > limit:
|
||||
keep_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == document.kb_id).order_by(cls.model.create_time.desc()).limit(limit)]
|
||||
|
||||
deleted = cls.model.delete().where(cls.model.kb_id == document.kb_id, cls.model.id.not_in(keep_ids)).execute()
|
||||
logging.info(f"[PipelineOperationLogService] Cleaned {deleted} old logs, kept latest {limit} for {document.kb_id}")
|
||||
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_file_logs_fields()
|
||||
if keywords:
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
||||
else:
|
||||
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||
|
||||
logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if types:
|
||||
logs = logs.where(cls.model.document_type.in_(types))
|
||||
if suffix:
|
||||
logs = logs.where(cls.model.document_suffix.in_(suffix))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_documents_info(cls, id):
|
||||
fields = [Document.id, Document.name, Document.progress, Document.kb_id]
|
||||
return (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.document_id == Document.id))
|
||||
.where(
|
||||
cls.model.id == id
|
||||
)
|
||||
.dicts()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_dataset_logs_fields()
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
@ -110,3 +110,8 @@ class SearchService(CommonService):
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@ -35,6 +35,8 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
|
||||
CANVAS_DEBUG_DOC_ID = "dataflow_x"
|
||||
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
|
||||
|
||||
def trim_header_by_lines(text: str, max_length) -> str:
|
||||
# Trim header text to maximum length while preserving line breaks
|
||||
@ -70,7 +72,7 @@ class TaskService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_task(cls, task_id):
|
||||
def get_task(cls, task_id, doc_ids=[]):
|
||||
"""Retrieve detailed task information by task ID.
|
||||
|
||||
This method fetches comprehensive task details including associated document,
|
||||
@ -84,6 +86,10 @@ class TaskService(CommonService):
|
||||
dict: Task details dictionary containing all task information and related metadata.
|
||||
Returns None if task is not found or has exceeded retry limit.
|
||||
"""
|
||||
doc_id = cls.model.doc_id
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
|
||||
doc_id = doc_ids[0]
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.doc_id,
|
||||
@ -109,7 +115,7 @@ class TaskService(CommonService):
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(Document, on=(doc_id == Document.id))
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == task_id)
|
||||
@ -292,21 +298,29 @@ class TaskService(CommonService):
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
return
|
||||
else:
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
process_duration = (datetime.now() - task.begin_at).total_seconds()
|
||||
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_doc_ids(cls, doc_ids):
|
||||
"""Delete task associated with a document."""
|
||||
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
@ -330,7 +344,14 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
- Previous task chunks may be reused if available
|
||||
"""
|
||||
def new_task():
|
||||
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"progress": 0.0,
|
||||
"from_page": 0,
|
||||
"to_page": 100000000,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
parse_task_array = []
|
||||
|
||||
@ -343,7 +364,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
page_size = doc["parser_config"].get("task_page_size") or 12
|
||||
if doc["parser_id"] == "paper":
|
||||
page_size = doc["parser_config"].get("task_page_size") or 22
|
||||
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
|
||||
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc", True):
|
||||
page_size = 10 ** 9
|
||||
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
|
||||
for s, e in page_ranges:
|
||||
@ -472,33 +493,26 @@ def has_canceled(task_id):
|
||||
return False
|
||||
|
||||
|
||||
def queue_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, priority: int, callback=None) -> tuple[bool, str]:
|
||||
"""
|
||||
Returns a tuple (success: bool, error_message: str).
|
||||
"""
|
||||
_ = callback
|
||||
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
|
||||
|
||||
task = dict(
|
||||
id=get_uuid() if not task_id else task_id,
|
||||
doc_id=doc_id,
|
||||
from_page=0,
|
||||
to_page=100000000,
|
||||
task_type="dataflow",
|
||||
priority=priority,
|
||||
id=task_id,
|
||||
doc_id=doc_id,
|
||||
from_page=0,
|
||||
to_page=100000000,
|
||||
task_type="dataflow" if not rerun else "dataflow_rerun",
|
||||
priority=priority,
|
||||
begin_at=datetime.now(),
|
||||
)
|
||||
|
||||
TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
|
||||
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
|
||||
TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
|
||||
DocumentService.begin2parse(doc_id)
|
||||
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
|
||||
|
||||
kb_id = DocumentService.get_knowledgebase_id(doc_id)
|
||||
if not kb_id:
|
||||
return False, f"Can't find KB of this document: {doc_id}"
|
||||
|
||||
task["kb_id"] = kb_id
|
||||
task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
|
||||
task["tenant_id"] = tenant_id
|
||||
task["task_type"] = "dataflow"
|
||||
task["dsl"] = dsl
|
||||
task["dataflow_id"] = get_uuid() if not flow_id else flow_id
|
||||
task["dataflow_id"] = flow_id
|
||||
task["file"] = file
|
||||
|
||||
if not REDIS_CONN.queue_product(
|
||||
get_svr_queue_name(priority), message=task
|
||||
|
||||
@ -209,6 +209,11 @@ class TenantLLMService(CommonService):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
from api.db.services.llm_service import LLMService
|
||||
|
||||
@ -24,7 +24,24 @@ class UserCanvasVersionService(CommonService):
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
|
||||
fields = [cls.model.id]
|
||||
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
|
||||
versions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
version_batch = versions.offset(offset).limit(limit)
|
||||
_temp = list(version_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_all_versions(cls, user_canvas_id):
|
||||
|
||||
@ -100,6 +100,12 @@ class UserService(CommonService):
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_user_by_email(cls, email):
|
||||
users = cls.model.select().where((cls.model.email == email))
|
||||
return list(users)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
@ -133,6 +139,17 @@ class UserService(CommonService):
|
||||
cls.model.update(user_dict).where(
|
||||
cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user_password(cls, user_id, new_password):
|
||||
with DB.atomic():
|
||||
update_dict = {
|
||||
"password": generate_password_hash(str(new_password)),
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": datetime_format(datetime.now())
|
||||
}
|
||||
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_admin(cls, user_id):
|
||||
@ -271,6 +288,17 @@ class UserTenantService(CommonService):
|
||||
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_user_tenant_relation_by_user_id(cls, user_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.user_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.role
|
||||
]
|
||||
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_num_members(cls, user_id: str):
|
||||
|
||||
@ -41,7 +41,7 @@ from api import utils
|
||||
from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.init_data import init_web_data
|
||||
from api.versions import get_ragflow_version
|
||||
from api.utils import show_configs
|
||||
from api.utils.configs import show_configs
|
||||
from rag.settings import print_rag_settings
|
||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
@ -24,7 +24,7 @@ import rag.utils.es_conn
|
||||
import rag.utils.infinity_conn
|
||||
import rag.utils.opensearch_conn
|
||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||
from api.utils import decrypt_database_config, get_base_config
|
||||
from api.utils.configs import decrypt_database_config, get_base_config
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
@ -16,182 +16,15 @@
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
import requests
|
||||
import logging
|
||||
import copy
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
import importlib
|
||||
from filelock import FileLock
|
||||
from api.constants import SERVICE_CONF
|
||||
|
||||
from . import file_utils
|
||||
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
|
||||
def read_config(conf_name=SERVICE_CONF):
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
|
||||
# load local config file
|
||||
if os.path.exists(local_path):
|
||||
local_config = file_utils.load_yaml_conf(local_path)
|
||||
if not isinstance(local_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||
|
||||
global_config_path = conf_realpath(conf_name)
|
||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||
|
||||
if not isinstance(global_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||
|
||||
global_config.update(local_config)
|
||||
return global_config
|
||||
|
||||
|
||||
CONFIGS = read_config()
|
||||
|
||||
|
||||
def show_configs():
|
||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||
for k, v in CONFIGS.items():
|
||||
if isinstance(v, dict):
|
||||
if "password" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["password"] = "*" * 8
|
||||
if "access_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["access_key"] = "*" * 8
|
||||
if "secret_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret_key"] = "*" * 8
|
||||
if "secret" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret"] = "*" * 8
|
||||
if "sas_token" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["sas_token"] = "*" * 8
|
||||
if "oauth" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "client_secret" in val:
|
||||
val["client_secret"] = "*" * 8
|
||||
if "authentication" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "http_secret_key" in val:
|
||||
val["http_secret_key"] = "*" * 8
|
||||
msg += f"\n\t{k}: {v}"
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def get_base_config(key, default=None):
|
||||
if key is None:
|
||||
return None
|
||||
if default is None:
|
||||
default = os.environ.get(key.upper())
|
||||
return CONFIGS.get(key, default)
|
||||
|
||||
|
||||
use_deserialize_safe_module = get_base_config(
|
||||
'use_deserialize_safe_module', False)
|
||||
|
||||
|
||||
class BaseType:
|
||||
def to_dict(self):
|
||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||
|
||||
def to_dict_with_type(self):
|
||||
def _dict(obj):
|
||||
module = None
|
||||
if issubclass(obj.__class__, BaseType):
|
||||
data = {}
|
||||
for attr, v in obj.__dict__.items():
|
||||
k = attr.lstrip("_")
|
||||
data[k] = _dict(v)
|
||||
module = obj.__module__
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
data = []
|
||||
for i, vv in enumerate(obj):
|
||||
data.append(_dict(vv))
|
||||
elif isinstance(obj, dict):
|
||||
data = {}
|
||||
for _k, vv in obj.items():
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__,
|
||||
"data": data, "module": module}
|
||||
|
||||
return _dict(self)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self._with_type = kwargs.pop("with_type", False)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(obj, datetime.date):
|
||||
return obj.strftime('%Y-%m-%d')
|
||||
elif isinstance(obj, datetime.timedelta):
|
||||
return str(obj)
|
||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||
return obj.value
|
||||
elif isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif issubclass(type(obj), BaseType):
|
||||
if not self._with_type:
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj.to_dict_with_type()
|
||||
elif isinstance(obj, type):
|
||||
return obj.__name__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def rag_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(
|
||||
string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(
|
||||
src,
|
||||
indent=indent,
|
||||
cls=CustomJSONEncoder,
|
||||
with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
|
||||
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook)
|
||||
from .common import string_to_bytes
|
||||
|
||||
|
||||
def current_timestamp():
|
||||
@ -213,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
|
||||
return time_stamp
|
||||
|
||||
|
||||
def serialize_b64(src, to_str=False):
|
||||
dest = base64.b64encode(pickle.dumps(src))
|
||||
if not to_str:
|
||||
return dest
|
||||
else:
|
||||
return bytes_to_string(dest)
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(
|
||||
string_to_bytes(src) if isinstance(
|
||||
src, str) else src)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
|
||||
|
||||
safe_module = {
|
||||
'numpy',
|
||||
'rag_flow'
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
import importlib
|
||||
if module.split('.')[0] in safe_module:
|
||||
_module = importlib.import_module(module)
|
||||
return getattr(_module, name)
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
|
||||
|
||||
def restricted_loads(src):
|
||||
"""Helper function analogous to pickle.loads()."""
|
||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||
|
||||
|
||||
def get_lan_ip():
|
||||
if os.name != "nt":
|
||||
import fcntl
|
||||
@ -296,47 +90,6 @@ def from_dict_hook(in_dict: dict):
|
||||
return in_dict
|
||||
|
||||
|
||||
def decrypt_database_password(password):
|
||||
encrypt_password = get_base_config("encrypt_password", False)
|
||||
encrypt_module = get_base_config("encrypt_module", False)
|
||||
private_key = get_base_config("private_key", None)
|
||||
|
||||
if not password or not encrypt_password:
|
||||
return password
|
||||
|
||||
if not private_key:
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(
|
||||
importlib.import_module(
|
||||
module_fun[0]),
|
||||
module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(
|
||||
database=None, passwd_key="password", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||
return database
|
||||
|
||||
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(
|
||||
file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
config[key] = value
|
||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
@ -375,5 +128,5 @@ def delta_seconds(date_string: str):
|
||||
return (datetime.datetime.now() - dt).total_seconds()
|
||||
|
||||
|
||||
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
|
||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
|
||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||
|
||||
@ -39,6 +39,7 @@ from flask import (
|
||||
make_response,
|
||||
send_file,
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from api import settings
|
||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||
from api.db import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services import UserService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||
from api.utils.json import CustomJSONEncoder, json_dumps
|
||||
from api.utils import get_uuid
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
|
||||
return decorator
|
||||
|
||||
|
||||
def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_localhost(ip):
|
||||
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
||||
|
||||
@ -643,6 +659,16 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
|
||||
return transformed_data
|
||||
|
||||
|
||||
def group_by(list_of_dict, key):
|
||||
res = {}
|
||||
for item in list_of_dict:
|
||||
if item[key] in res.keys():
|
||||
res[item[key]].append(item)
|
||||
else:
|
||||
res[item[key]] = [item]
|
||||
return res
|
||||
|
||||
|
||||
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
|
||||
results = {}
|
||||
tool_call_sessions = []
|
||||
@ -679,7 +705,9 @@ TimeoutException = Union[Type[BaseException], BaseException]
|
||||
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
if isinstance(seconds, str):
|
||||
seconds = float(seconds)
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
@ -1,3 +1,56 @@
|
||||
import base64
|
||||
import logging
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
|
||||
|
||||
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import trio
|
||||
from rag.svr.task_executor import minio_limiter
|
||||
if not d.get("image"):
|
||||
return
|
||||
|
||||
with BytesIO() as output_buffer:
|
||||
if isinstance(d["image"], bytes):
|
||||
output_buffer.write(d["image"])
|
||||
output_buffer.seek(0)
|
||||
else:
|
||||
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
|
||||
if d["image"].mode in ("RGBA", "P"):
|
||||
converted_image = d["image"].convert("RGB")
|
||||
d["image"] = converted_image
|
||||
try:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
except OSError as e:
|
||||
logging.warning(
|
||||
"Saving image exception, ignore: {}".format(str(e)))
|
||||
|
||||
async with minio_limiter:
|
||||
await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue()))
|
||||
d["img_id"] = f"{bucket}-{objname}"
|
||||
if not isinstance(d["image"], bytes):
|
||||
d["image"].close()
|
||||
del d["image"] # Remove image reference
|
||||
|
||||
|
||||
def id2image(image_id:str|None, storage_get_func: partial):
|
||||
if not image_id:
|
||||
return
|
||||
arr = image_id.split("-")
|
||||
if len(arr) != 2:
|
||||
return
|
||||
bkt, nm = image_id.split("-")
|
||||
try:
|
||||
blob = storage_get_func(bucket=bkt, filename=nm)
|
||||
if not blob:
|
||||
return
|
||||
return Image.open(BytesIO(blob))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
23
api/utils/common.py
Normal file
23
api/utils/common.py
Normal file
@ -0,0 +1,23 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(
|
||||
string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
179
api/utils/configs.py
Normal file
179
api/utils/configs.py
Normal file
@ -0,0 +1,179 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import io
|
||||
import copy
|
||||
import logging
|
||||
import base64
|
||||
import pickle
|
||||
import importlib
|
||||
|
||||
from api.utils import file_utils
|
||||
from filelock import FileLock
|
||||
from api.utils.common import bytes_to_string, string_to_bytes
|
||||
from api.constants import SERVICE_CONF
|
||||
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
|
||||
def read_config(conf_name=SERVICE_CONF):
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
|
||||
# load local config file
|
||||
if os.path.exists(local_path):
|
||||
local_config = file_utils.load_yaml_conf(local_path)
|
||||
if not isinstance(local_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||
|
||||
global_config_path = conf_realpath(conf_name)
|
||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||
|
||||
if not isinstance(global_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||
|
||||
global_config.update(local_config)
|
||||
return global_config
|
||||
|
||||
|
||||
CONFIGS = read_config()
|
||||
|
||||
|
||||
def show_configs():
|
||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||
for k, v in CONFIGS.items():
|
||||
if isinstance(v, dict):
|
||||
if "password" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["password"] = "*" * 8
|
||||
if "access_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["access_key"] = "*" * 8
|
||||
if "secret_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret_key"] = "*" * 8
|
||||
if "secret" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret"] = "*" * 8
|
||||
if "sas_token" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["sas_token"] = "*" * 8
|
||||
if "oauth" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "client_secret" in val:
|
||||
val["client_secret"] = "*" * 8
|
||||
if "authentication" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "http_secret_key" in val:
|
||||
val["http_secret_key"] = "*" * 8
|
||||
msg += f"\n\t{k}: {v}"
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def get_base_config(key, default=None):
|
||||
if key is None:
|
||||
return None
|
||||
if default is None:
|
||||
default = os.environ.get(key.upper())
|
||||
return CONFIGS.get(key, default)
|
||||
|
||||
|
||||
def decrypt_database_password(password):
|
||||
encrypt_password = get_base_config("encrypt_password", False)
|
||||
encrypt_module = get_base_config("encrypt_module", False)
|
||||
private_key = get_base_config("private_key", None)
|
||||
|
||||
if not password or not encrypt_password:
|
||||
return password
|
||||
|
||||
if not private_key:
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(
|
||||
importlib.import_module(
|
||||
module_fun[0]),
|
||||
module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(
|
||||
database=None, passwd_key="password", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||
return database
|
||||
|
||||
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(
|
||||
file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
config[key] = value
|
||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||
|
||||
|
||||
safe_module = {
|
||||
'numpy',
|
||||
'rag_flow'
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
import importlib
|
||||
if module.split('.')[0] in safe_module:
|
||||
_module = importlib.import_module(module)
|
||||
return getattr(_module, name)
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
|
||||
|
||||
def restricted_loads(src):
|
||||
"""Helper function analogous to pickle.loads()."""
|
||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||
|
||||
|
||||
def serialize_b64(src, to_str=False):
|
||||
dest = base64.b64encode(pickle.dumps(src))
|
||||
if not to_str:
|
||||
return dest
|
||||
else:
|
||||
return bytes_to_string(dest)
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(
|
||||
string_to_bytes(src) if isinstance(
|
||||
src, str) else src)
|
||||
use_deserialize_safe_module = get_base_config(
|
||||
'use_deserialize_safe_module', False)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
@ -23,6 +23,9 @@ from api.utils import file_utils
|
||||
|
||||
|
||||
def crypt(line):
|
||||
"""
|
||||
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||
"""
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
|
||||
@ -155,7 +155,7 @@ def filename_type(filename):
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
return FileType.PDF.value
|
||||
|
||||
if re.match(r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
|
||||
if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
|
||||
return FileType.DOC.value
|
||||
|
||||
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename):
|
||||
|
||||
104
api/utils/health.py
Normal file
104
api/utils/health.py
Normal file
@ -0,0 +1,104 @@
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api import settings
|
||||
from api.db.db_models import DB
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
def _ok_nok(ok: bool) -> str:
|
||||
return "ok" if ok else "nok"
|
||||
|
||||
|
||||
def check_db() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
# lightweight probe; works for MySQL/Postgres
|
||||
DB.execute_sql("SELECT 1")
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_redis() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
ok = bool(REDIS_CONN.health())
|
||||
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_doc_engine() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
meta = settings.docStoreConn.health()
|
||||
# treat any successful call as ok
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_storage() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
STORAGE_IMPL.health()
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_chat() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
cfg = getattr(settings, "CHAT_CFG", None)
|
||||
ok = bool(cfg and cfg.get("factory"))
|
||||
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def run_health_checks() -> tuple[dict, bool]:
|
||||
result: dict[str, str | dict] = {}
|
||||
|
||||
db_ok, db_meta = check_db()
|
||||
chat_ok, chat_meta = check_chat()
|
||||
|
||||
result["db"] = _ok_nok(db_ok)
|
||||
if not db_ok:
|
||||
result.setdefault("_meta", {})["db"] = db_meta
|
||||
|
||||
result["chat"] = _ok_nok(chat_ok)
|
||||
if not chat_ok:
|
||||
result.setdefault("_meta", {})["chat"] = chat_meta
|
||||
|
||||
# Optional probes (do not change minimal contract but exposed for observability)
|
||||
try:
|
||||
redis_ok, redis_meta = check_redis()
|
||||
result["redis"] = _ok_nok(redis_ok)
|
||||
if not redis_ok:
|
||||
result.setdefault("_meta", {})["redis"] = redis_meta
|
||||
except Exception:
|
||||
result["redis"] = "nok"
|
||||
|
||||
try:
|
||||
doc_ok, doc_meta = check_doc_engine()
|
||||
result["doc_engine"] = _ok_nok(doc_ok)
|
||||
if not doc_ok:
|
||||
result.setdefault("_meta", {})["doc_engine"] = doc_meta
|
||||
except Exception:
|
||||
result["doc_engine"] = "nok"
|
||||
|
||||
try:
|
||||
sto_ok, sto_meta = check_storage()
|
||||
result["storage"] = _ok_nok(sto_ok)
|
||||
if not sto_ok:
|
||||
result.setdefault("_meta", {})["storage"] = sto_meta
|
||||
except Exception:
|
||||
result["storage"] = "nok"
|
||||
|
||||
all_ok = (result.get("db") == "ok") and (result.get("chat") == "ok")
|
||||
result["status"] = "ok" if all_ok else "nok"
|
||||
return result, all_ok
|
||||
|
||||
|
||||
78
api/utils/json.py
Normal file
78
api/utils/json.py
Normal file
@ -0,0 +1,78 @@
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum, IntEnum
|
||||
from api.utils.common import string_to_bytes, bytes_to_string
|
||||
|
||||
|
||||
class BaseType:
|
||||
def to_dict(self):
|
||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||
|
||||
def to_dict_with_type(self):
|
||||
def _dict(obj):
|
||||
module = None
|
||||
if issubclass(obj.__class__, BaseType):
|
||||
data = {}
|
||||
for attr, v in obj.__dict__.items():
|
||||
k = attr.lstrip("_")
|
||||
data[k] = _dict(v)
|
||||
module = obj.__module__
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
data = []
|
||||
for i, vv in enumerate(obj):
|
||||
data.append(_dict(vv))
|
||||
elif isinstance(obj, dict):
|
||||
data = {}
|
||||
for _k, vv in obj.items():
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__,
|
||||
"data": data, "module": module}
|
||||
|
||||
return _dict(self)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self._with_type = kwargs.pop("with_type", False)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(obj, datetime.date):
|
||||
return obj.strftime('%Y-%m-%d')
|
||||
elif isinstance(obj, datetime.timedelta):
|
||||
return str(obj)
|
||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||
return obj.value
|
||||
elif isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif issubclass(type(obj), BaseType):
|
||||
if not self._with_type:
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj.to_dict_with_type()
|
||||
elif isinstance(obj, type):
|
||||
return obj.__name__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(
|
||||
src,
|
||||
indent=indent,
|
||||
cls=CustomJSONEncoder,
|
||||
with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
|
||||
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook)
|
||||
@ -1075,11 +1075,10 @@ class RAGFlowPdfParser:
|
||||
def insert_table_figures(tbls_or_figs, layout_type):
|
||||
def min_rectangle_distance(rect1, rect2):
|
||||
import math
|
||||
|
||||
pn1, left1, right1, top1, bottom1 = rect1
|
||||
pn2, left2, right2, top2, bottom2 = rect2
|
||||
if right1 >= left2 and right2 >= left1 and bottom1 >= top2 and bottom2 >= top1:
|
||||
return 0 + (pn1 - pn2) * 10000
|
||||
return 0
|
||||
if right1 < left2:
|
||||
dx = left2 - right1
|
||||
elif right2 < left1:
|
||||
@ -1092,20 +1091,27 @@ class RAGFlowPdfParser:
|
||||
dy = top1 - bottom2
|
||||
else:
|
||||
dy = 0
|
||||
return math.sqrt(dx * dx + dy * dy) + (pn1 - pn2) * 10000
|
||||
return math.sqrt(dx*dx + dy*dy)# + (pn2-pn1)*10000
|
||||
|
||||
for (img, txt), poss in tbls_or_figs:
|
||||
bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)]
|
||||
dists = [(min_rectangle_distance((pn, left, right, top, bott), rect), i) for i, rect in bboxes for pn, left, right, top, bott in poss]
|
||||
dists = [(min_rectangle_distance((pn, left, right, top+self.page_cum_height[pn], bott+self.page_cum_height[pn]), rect),i) for i, rect in bboxes for pn, left, right, top, bott in poss]
|
||||
min_i = np.argmin(dists, axis=0)[0]
|
||||
min_i, rect = bboxes[dists[min_i][-1]]
|
||||
if isinstance(txt, list):
|
||||
txt = "\n".join(txt)
|
||||
self.boxes.insert(min_i, {"page_number": rect[0], "x0": rect[1], "x1": rect[2], "top": rect[3], "bottom": rect[4], "layout_type": layout_type, "text": txt, "image": img})
|
||||
pn, left, right, top, bott = poss[0]
|
||||
if self.boxes[min_i]["bottom"] < top+self.page_cum_height[pn]:
|
||||
min_i += 1
|
||||
self.boxes.insert(min_i, {
|
||||
"page_number": pn+1, "x0": left, "x1": right, "top": top+self.page_cum_height[pn], "bottom": bott+self.page_cum_height[pn], "layout_type": layout_type, "text": txt, "image": img,
|
||||
"positions": [[pn+1, int(left), int(right), int(top), int(bott)]]
|
||||
})
|
||||
|
||||
for b in self.boxes:
|
||||
b["position_tag"] = self._line_tag(b, zoomin)
|
||||
b["image"] = self.crop(b["position_tag"], zoomin)
|
||||
b["positions"] = [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(b["position_tag"])]
|
||||
|
||||
insert_table_figures(tbls, "table")
|
||||
insert_table_figures(figs, "figure")
|
||||
@ -1123,7 +1129,7 @@ class RAGFlowPdfParser:
|
||||
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
|
||||
pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
|
||||
left, right, top, bottom = float(left), float(right), float(top), float(bottom)
|
||||
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
|
||||
poss.append(([int(p) - 1 for p in pn.split("-")], int(left), int(right), int(top), int(bottom)))
|
||||
return poss
|
||||
|
||||
def crop(self, text, ZM=3, need_position=False):
|
||||
|
||||
@ -350,7 +350,7 @@ class TextRecognizer:
|
||||
|
||||
def close(self):
|
||||
# close session and release manually
|
||||
logging.info('Close TextRecognizer.')
|
||||
logging.info('Close text recognizer.')
|
||||
if hasattr(self, "predictor"):
|
||||
del self.predictor
|
||||
gc.collect()
|
||||
@ -490,7 +490,7 @@ class TextDetector:
|
||||
return dt_boxes
|
||||
|
||||
def close(self):
|
||||
logging.info("Close TextDetector.")
|
||||
logging.info("Close text detector.")
|
||||
if hasattr(self, "predictor"):
|
||||
del self.predictor
|
||||
gc.collect()
|
||||
|
||||
@ -65,7 +65,7 @@ A complete list of models supported by RAGFlow, which will continue to expand.
|
||||
| 01.AI | :heavy_check_mark: | | | | | |
|
||||
| DeepInfra | :heavy_check_mark: | :heavy_check_mark: | | | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| 302.AI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| CometAPI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| CometAPI | :heavy_check_mark: | :heavy_check_mark: | | | | |
|
||||
|
||||
```mdx-code-block
|
||||
</APITable>
|
||||
|
||||
@ -21,6 +21,7 @@ import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import timeout
|
||||
from graphrag.entity_resolution import EntityResolution
|
||||
@ -54,7 +55,7 @@ async def run_graphrag(
|
||||
start = trio.current_time()
|
||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||
chunks = []
|
||||
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]):
|
||||
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||
@ -125,6 +126,212 @@ async def run_graphrag(
|
||||
return
|
||||
|
||||
|
||||
async def run_graphrag_for_kb(
|
||||
row: dict,
|
||||
doc_ids: list[str],
|
||||
language: str,
|
||||
kb_parser_config: dict,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
*,
|
||||
with_resolution: bool = True,
|
||||
with_community: bool = True,
|
||||
max_parallel_docs: int = 4,
|
||||
) -> dict:
|
||||
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = trio.current_time()
|
||||
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||
|
||||
if not doc_ids:
|
||||
logging.info(f"Fetching all docs for {kb_id}")
|
||||
docs, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
doc_ids = [doc["id"] for doc in docs]
|
||||
|
||||
doc_ids = list(dict.fromkeys(doc_ids))
|
||||
if not doc_ids:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
|
||||
return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
def load_doc_chunks(doc_id: str) -> list[str]:
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for d in settings.retrievaler.chunk_list(
|
||||
doc_id,
|
||||
tenant_id,
|
||||
[kb_id],
|
||||
fields=fields_for_chunks,
|
||||
sort_by_position=True,
|
||||
):
|
||||
content = d["content_with_weight"]
|
||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||
current_chunk += content
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = content
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
all_doc_chunks: dict[str, list[str]] = {}
|
||||
total_chunks = 0
|
||||
for doc_id in doc_ids:
|
||||
chunks = load_doc_chunks(doc_id)
|
||||
all_doc_chunks[doc_id] = chunks
|
||||
total_chunks += len(chunks)
|
||||
|
||||
if total_chunks == 0:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
semaphore = trio.Semaphore(max_parallel_docs)
|
||||
|
||||
subgraphs: dict[str, object] = {}
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
|
||||
async def build_one(doc_id: str):
|
||||
chunks = all_doc_chunks.get(doc_id, [])
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
return
|
||||
|
||||
kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
|
||||
|
||||
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
|
||||
async with semaphore:
|
||||
try:
|
||||
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||
with trio.fail_after(deadline):
|
||||
sg = await generate_subgraph(
|
||||
kg_extractor,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if sg:
|
||||
subgraphs[doc_id] = sg
|
||||
callback(msg=f"{msg} done")
|
||||
else:
|
||||
failed_docs.append((doc_id, "subgraph is empty"))
|
||||
callback(msg=f"{msg} empty")
|
||||
except Exception as e:
|
||||
failed_docs.append((doc_id, repr(e)))
|
||||
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for doc_id in doc_ids:
|
||||
nursery.start_soon(build_one, doc_id)
|
||||
|
||||
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||
if not ok_docs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||
now = trio.current_time()
|
||||
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
|
||||
|
||||
try:
|
||||
union_nodes: set = set()
|
||||
final_graph = None
|
||||
|
||||
for doc_id in ok_docs:
|
||||
sg = subgraphs[doc_id]
|
||||
union_nodes.update(set(sg.nodes()))
|
||||
|
||||
new_graph = await merge_subgraph(
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
sg,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if new_graph is not None:
|
||||
final_graph = new_graph
|
||||
|
||||
if final_graph is None:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
|
||||
else:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
|
||||
finally:
|
||||
kb_lock.release()
|
||||
|
||||
if not with_resolution and not with_community:
|
||||
now = trio.current_time()
|
||||
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
await kb_lock.spin_acquire()
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
|
||||
|
||||
try:
|
||||
subgraph_nodes = set()
|
||||
for sg in subgraphs.values():
|
||||
subgraph_nodes.update(set(sg.nodes()))
|
||||
|
||||
if with_resolution:
|
||||
await resolve_entities(
|
||||
final_graph,
|
||||
subgraph_nodes,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
None,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
|
||||
if with_community:
|
||||
await extract_community(
|
||||
final_graph,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
None,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
finally:
|
||||
kb_lock.release()
|
||||
|
||||
now = trio.current_time()
|
||||
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
|
||||
return {
|
||||
"ok_docs": ok_docs,
|
||||
"failed_docs": failed_docs, # [(doc_id, error), ...]
|
||||
"total_docs": len(doc_ids),
|
||||
"total_chunks": total_chunks,
|
||||
"seconds": now - start,
|
||||
}
|
||||
|
||||
|
||||
async def generate_subgraph(
|
||||
extractor: Extractor,
|
||||
tenant_id: str,
|
||||
|
||||
@ -34,6 +34,7 @@ dependencies = [
|
||||
"elastic-transport==8.12.0",
|
||||
"elasticsearch==8.12.1",
|
||||
"elasticsearch-dsl==8.12.0",
|
||||
"extract-msg>=0.39.0",
|
||||
"filelock==3.15.4",
|
||||
"flask==3.0.3",
|
||||
"flask-cors==5.0.0",
|
||||
@ -157,6 +158,9 @@ test = [
|
||||
"requests-toolbelt>=1.0.0",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://mirrors.aliyun.com/pypi/simple"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = [
|
||||
'agent',
|
||||
@ -170,9 +174,6 @@ packages = [
|
||||
'sdk.python.ragflow_sdk',
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://mirrors.aliyun.com/pypi/simple"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 200
|
||||
exclude = [".venv", "rag/svr/discord_svr.py"]
|
||||
|
||||
@ -78,7 +78,7 @@ def chunk(
|
||||
_add_content(msg, msg.get_content_type())
|
||||
|
||||
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
|
||||
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line
|
||||
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
|
||||
]
|
||||
|
||||
st = timer()
|
||||
|
||||
@ -18,9 +18,7 @@ import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
@ -36,9 +34,9 @@ class ProcessBase(ComponentBase):
|
||||
def __init__(self, pipeline, id, param: ProcessParamBase):
|
||||
super().__init__(pipeline, id, param)
|
||||
if hasattr(self._canvas, "callback"):
|
||||
self.callback = partial(self._canvas.callback, self.component_name)
|
||||
self.callback = partial(self._canvas.callback, id)
|
||||
else:
|
||||
self.callback = partial(lambda *args, **kwargs: None, self.component_name)
|
||||
self.callback = partial(lambda *args, **kwargs: None, id)
|
||||
|
||||
async def invoke(self, **kwargs) -> dict[str, Any]:
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
@ -58,6 +56,6 @@ class ProcessBase(ComponentBase):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
async def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -12,18 +12,19 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import random
|
||||
|
||||
import trio
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from graphrag.utils import chat_limiter, get_llm_cache, set_llm_cache
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.chunker.schema import ChunkerFromUpstream
|
||||
from rag.nlp import naive_merge, naive_merge_with_images
|
||||
from rag.prompts.generator import keyword_extraction, question_proposal
|
||||
from rag.nlp import naive_merge, naive_merge_with_images, concat_img
|
||||
from rag.prompts.prompts import keyword_extraction, question_proposal, detect_table_of_contents, \
|
||||
table_of_contents_index, toc_transformer
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class ChunkerParam(ProcessParamBase):
|
||||
@ -43,6 +44,7 @@ class ChunkerParam(ProcessParamBase):
|
||||
"paper",
|
||||
"laws",
|
||||
"presentation",
|
||||
"toc" # table of contents
|
||||
# Other
|
||||
# "Tag" # TODO: Other method
|
||||
]
|
||||
@ -54,7 +56,7 @@ class ChunkerParam(ProcessParamBase):
|
||||
self.auto_keywords = 0
|
||||
self.auto_questions = 0
|
||||
self.tag_sets = []
|
||||
self.llm_setting = {"llm_name": "", "lang": "Chinese"}
|
||||
self.llm_setting = {"llm_id": "", "lang": "Chinese"}
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
|
||||
@ -142,6 +144,91 @@ class Chunker(ProcessBase):
|
||||
def _one(self, from_upstream: ChunkerFromUpstream):
|
||||
pass
|
||||
|
||||
def _toc(self, from_upstream: ChunkerFromUpstream):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to chunk via `ToC`.")
|
||||
if from_upstream.output_format in ["markdown", "text", "html"]:
|
||||
return
|
||||
|
||||
# json
|
||||
sections, section_images, page_1024, tc_arr = [], [], [""], [0]
|
||||
for o in from_upstream.json_result or []:
|
||||
txt = o.get("text", "")
|
||||
tc = num_tokens_from_string(txt)
|
||||
page_1024[-1] += "\n" + txt
|
||||
tc_arr[-1] += tc
|
||||
if tc_arr[-1] > 1024:
|
||||
page_1024.append("")
|
||||
tc_arr.append(0)
|
||||
sections.append((o.get("text", ""), o.get("position_tag", "")))
|
||||
section_images.append(o.get("image"))
|
||||
print(len(sections), o)
|
||||
|
||||
llm_setting = self._param.llm_setting
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_id"], lang=llm_setting["lang"])
|
||||
self.callback(random.randint(5, 15) / 100.0, "Start to detect table of contents...")
|
||||
toc_secs = detect_table_of_contents(page_1024, chat_mdl)
|
||||
if toc_secs:
|
||||
self.callback(random.randint(25, 35) / 100.0, "Start to extract table of contents...")
|
||||
toc_arr = toc_transformer(toc_secs, chat_mdl)
|
||||
toc_arr = [it for it in toc_arr if it.get("structure")]
|
||||
print(json.dumps(toc_arr, ensure_ascii=False, indent=2), flush=True)
|
||||
self.callback(random.randint(35, 75) / 100.0, "Start to link table of contents...")
|
||||
toc_arr = table_of_contents_index(toc_arr, [t for t,_ in sections], chat_mdl)
|
||||
for i in range(len(toc_arr)-1):
|
||||
if not toc_arr[i].get("indices"):
|
||||
continue
|
||||
|
||||
for j in range(i+1, len(toc_arr)):
|
||||
if toc_arr[j].get("indices"):
|
||||
if toc_arr[j]["indices"][0] - toc_arr[i]["indices"][-1] > 1:
|
||||
toc_arr[i]["indices"].extend([x for x in range(toc_arr[i]["indices"][-1]+1, toc_arr[j]["indices"][0])])
|
||||
break
|
||||
# put all sections ahead of toc_arr[0] into it
|
||||
# for i in range(len(toc_arr)):
|
||||
# if toc_arr[i].get("indices") and toc_arr[i]["indices"][0]:
|
||||
# toc_arr[i]["indices"] = [x for x in range(toc_arr[i]["indices"][-1]+1)]
|
||||
# break
|
||||
# put all sections after toc_arr[-1] into it
|
||||
for i in range(len(toc_arr)-1, -1, -1):
|
||||
if toc_arr[i].get("indices") and toc_arr[i]["indices"][-1]:
|
||||
toc_arr[i]["indices"] = [x for x in range(toc_arr[i]["indices"][0], len(sections))]
|
||||
break
|
||||
print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n", json.dumps(toc_arr, ensure_ascii=False, indent=2), flush=True)
|
||||
|
||||
chunks, images = [], []
|
||||
for it in toc_arr:
|
||||
if not it.get("indices"):
|
||||
continue
|
||||
txt = ""
|
||||
img = None
|
||||
for i in it["indices"]:
|
||||
idx = i
|
||||
txt += "\n" + sections[idx][0] + "\t" + sections[idx][1]
|
||||
if img and section_images[idx]:
|
||||
img = concat_img(img, section_images[idx])
|
||||
elif section_images[idx]:
|
||||
img = section_images[idx]
|
||||
|
||||
it["indices"] = []
|
||||
if not txt:
|
||||
continue
|
||||
it["indices"] = [len(chunks)]
|
||||
print(it, "KKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK\n", txt)
|
||||
chunks.append(txt)
|
||||
images.append(img)
|
||||
self.callback(1, "Done")
|
||||
return [
|
||||
{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": RAGFlowPdfParser.extract_positions(c),
|
||||
}
|
||||
for c, img in zip(chunks, images)
|
||||
]
|
||||
|
||||
self.callback(message="No table of contents detected.")
|
||||
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
function_map = {
|
||||
"general": self._general,
|
||||
@ -167,7 +254,7 @@ class Chunker(ProcessBase):
|
||||
|
||||
async def auto_keywords():
|
||||
nonlocal chunks, llm_setting
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_id"], lang=llm_setting["lang"])
|
||||
|
||||
async def doc_keyword_extraction(chat_mdl, ck, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "keywords", {"topn": topn})
|
||||
@ -184,7 +271,7 @@ class Chunker(ProcessBase):
|
||||
|
||||
async def auto_questions():
|
||||
nonlocal chunks, llm_setting
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_id"], lang=llm_setting["lang"])
|
||||
|
||||
async def doc_question_proposal(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "question", {"topn": topn})
|
||||
|
||||
@ -22,7 +22,7 @@ class ChunkerFromUpstream(BaseModel):
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
blob: bytes
|
||||
file: dict | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
|
||||
|
||||
15
rag/flow/extractor/__init__.py
Normal file
15
rag/flow/extractor/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
63
rag/flow/extractor/extractor.py
Normal file
63
rag/flow/extractor/extractor.py
Normal file
@ -0,0 +1,63 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
|
||||
|
||||
class ExtractorParam(ProcessParamBase, LLMParam):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.field_name = ""
|
||||
|
||||
def check(self):
|
||||
super().check()
|
||||
self.check_empty(self.field_name, "Result Destination")
|
||||
|
||||
|
||||
class Extractor(ProcessBase, LLM):
|
||||
component_name = "Extractor"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
self.set_output("output_format", "chunks")
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||
inputs = self.get_input_elements()
|
||||
chunks = []
|
||||
chunks_key = ""
|
||||
args = {}
|
||||
for k, v in inputs.items():
|
||||
args[k] = v["value"]
|
||||
if isinstance(args[k], list):
|
||||
chunks = deepcopy(args[k])
|
||||
chunks_key = k
|
||||
|
||||
if chunks:
|
||||
prog = 0
|
||||
for i, ck in enumerate(chunks):
|
||||
args[chunks_key] = ck["text"]
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
ck[self._param.field_name] = self._generate(msg)
|
||||
prog += 1./len(chunks)
|
||||
if i % (len(chunks)//100+1) == 1:
|
||||
self.callback(prog, f"{i+1} / {len(chunks)}")
|
||||
self.set_output("chunks", chunks)
|
||||
else:
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
|
||||
|
||||
|
||||
38
rag/flow/extractor/schema.py
Normal file
38
rag/flow/extractor/schema.py
Normal file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ExtractorFromUpstream(BaseModel):
|
||||
created_time: float | None = Field(default=None, alias="_created_time")
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
html_result: str | None = Field(default=None, alias="html")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
@ -14,10 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class FileParam(ProcessParamBase):
|
||||
@ -41,10 +38,13 @@ class File(ProcessBase):
|
||||
self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!")
|
||||
return
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
|
||||
self.set_output("blob", STORAGE_IMPL.get(b, n))
|
||||
#b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
|
||||
#self.set_output("blob", STORAGE_IMPL.get(b, n))
|
||||
self.set_output("name", doc.name)
|
||||
else:
|
||||
file = kwargs.get("file")
|
||||
self.set_output("name", file["name"])
|
||||
self.set_output("blob", FileService.get_blob(file["created_by"], file["id"]))
|
||||
self.set_output("file", file)
|
||||
#self.set_output("blob", FileService.get_blob(file["created_by"], file["id"]))
|
||||
|
||||
self.callback(1, "File fetched.")
|
||||
|
||||
15
rag/flow/hierarchical_merger/__init__.py
Normal file
15
rag/flow/hierarchical_merger/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
186
rag/flow/hierarchical_merger/hierarchical_merger.py
Normal file
186
rag/flow/hierarchical_merger/hierarchical_merger.py
Normal file
@ -0,0 +1,186 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
||||
from api.utils import get_uuid
|
||||
from api.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
|
||||
from rag.nlp import concat_img
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class HierarchicalMergerParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.levels = []
|
||||
self.hierarchy = None
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.levels, "Hierarchical setups.")
|
||||
self.check_empty(self.hierarchy, "Hierarchy number.")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {}
|
||||
|
||||
|
||||
class HierarchicalMerger(ProcessBase):
|
||||
component_name = "HierarchicalMerger"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
try:
|
||||
from_upstream = HierarchicalMergerFromUpstream.model_validate(kwargs)
|
||||
except Exception as e:
|
||||
self.set_output("_ERROR", f"Input error: {str(e)}")
|
||||
return
|
||||
|
||||
self.set_output("output_format", "chunks")
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to merge hierarchically.")
|
||||
if from_upstream.output_format in ["markdown", "text", "html"]:
|
||||
if from_upstream.output_format == "markdown":
|
||||
payload = from_upstream.markdown_result
|
||||
elif from_upstream.output_format == "text":
|
||||
payload = from_upstream.text_result
|
||||
else: # == "html"
|
||||
payload = from_upstream.html_result
|
||||
|
||||
if not payload:
|
||||
payload = ""
|
||||
|
||||
lines = [ln for ln in payload.split("\n") if ln]
|
||||
else:
|
||||
arr = from_upstream.chunks if from_upstream.output_format == "chunks" else from_upstream.json_result
|
||||
lines = [o.get("text", "") for o in arr]
|
||||
sections, section_images = [], []
|
||||
for o in arr or []:
|
||||
sections.append((o.get("text", ""), o.get("position_tag", "")))
|
||||
section_images.append(o.get("img_id"))
|
||||
|
||||
matches = []
|
||||
for txt in lines:
|
||||
good = False
|
||||
for lvl, regs in enumerate(self._param.levels):
|
||||
for reg in regs:
|
||||
if re.search(reg, txt):
|
||||
matches.append(lvl)
|
||||
good = True
|
||||
break
|
||||
if good:
|
||||
break
|
||||
if not good:
|
||||
matches.append(len(self._param.levels))
|
||||
assert len(matches) == len(lines), f"{len(matches)} vs. {len(lines)}"
|
||||
|
||||
root = {
|
||||
"level": -1,
|
||||
"index": -1,
|
||||
"texts": [],
|
||||
"children": []
|
||||
}
|
||||
for i, m in enumerate(matches):
|
||||
if m == 0:
|
||||
root["children"].append({
|
||||
"level": m,
|
||||
"index": i,
|
||||
"texts": [],
|
||||
"children": []
|
||||
})
|
||||
elif m == len(self._param.levels):
|
||||
def dfs(b):
|
||||
if not b["children"]:
|
||||
b["texts"].append(i)
|
||||
else:
|
||||
dfs(b["children"][-1])
|
||||
dfs(root)
|
||||
else:
|
||||
def dfs(b):
|
||||
nonlocal m, i
|
||||
if not b["children"] or m == b["level"] + 1:
|
||||
b["children"].append({
|
||||
"level": m,
|
||||
"index": i,
|
||||
"texts": [],
|
||||
"children": []
|
||||
})
|
||||
return
|
||||
dfs(b["children"][-1])
|
||||
|
||||
dfs(root)
|
||||
|
||||
all_pathes = []
|
||||
def dfs(n, path, depth):
|
||||
nonlocal all_pathes
|
||||
if not n["children"] and path:
|
||||
all_pathes.append(path)
|
||||
|
||||
for nn in n["children"]:
|
||||
if depth < self._param.hierarchy:
|
||||
_path = deepcopy(path)
|
||||
else:
|
||||
_path = path
|
||||
_path.extend([nn["index"], *nn["texts"]])
|
||||
dfs(nn, _path, depth+1)
|
||||
|
||||
if depth == self._param.hierarchy:
|
||||
all_pathes.append(_path)
|
||||
|
||||
for i in range(len(lines)):
|
||||
print(i, lines[i])
|
||||
dfs(root, [], 0)
|
||||
|
||||
if root["texts"]:
|
||||
all_pathes.insert(0, root["texts"])
|
||||
if from_upstream.output_format in ["markdown", "text", "html"]:
|
||||
cks = []
|
||||
for path in all_pathes:
|
||||
txt = ""
|
||||
for i in path:
|
||||
txt += lines[i] + "\n"
|
||||
cks.append(txt)
|
||||
|
||||
self.set_output("chunks", [{"text": c} for c in cks if c])
|
||||
else:
|
||||
cks = []
|
||||
images = []
|
||||
for path in all_pathes:
|
||||
txt = ""
|
||||
img = None
|
||||
for i in path:
|
||||
txt += lines[i] + "\n"
|
||||
concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get)))
|
||||
cks.append(txt)
|
||||
images.append(img)
|
||||
|
||||
cks = [
|
||||
{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": RAGFlowPdfParser.extract_positions(c),
|
||||
}
|
||||
for c, img in zip(cks, images)
|
||||
]
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in cks:
|
||||
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
||||
self.set_output("chunks", cks)
|
||||
|
||||
self.callback(1, "Done.")
|
||||
37
rag/flow/hierarchical_merger/schema.py
Normal file
37
rag/flow/hierarchical_merger/schema.py
Normal file
@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class HierarchicalMergerFromUpstream(BaseModel):
|
||||
created_time: float | None = Field(default=None, alias="_created_time")
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "chunks"] | None = Field(default=None)
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
html_result: str | None = Field(default=None, alias="html")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
@ -13,20 +13,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils import get_uuid
|
||||
from api.utils.base64_image import image2id
|
||||
from deepdoc.parser import ExcelParser
|
||||
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
|
||||
from rag.app.naive import Docx
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.parser.schema import ParserFromUpstream
|
||||
from rag.llm.cv_model import Base as VLM
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class ParserParam(ProcessParamBase):
|
||||
@ -45,12 +53,14 @@ class ParserParam(ProcessParamBase):
|
||||
"word": [
|
||||
"json",
|
||||
],
|
||||
"ppt": [],
|
||||
"slides": [
|
||||
"json",
|
||||
],
|
||||
"image": [
|
||||
"text"
|
||||
],
|
||||
"email": [],
|
||||
"text": [
|
||||
"email": ["text", "json"],
|
||||
"text&markdown": [
|
||||
"text",
|
||||
"json"
|
||||
],
|
||||
@ -63,7 +73,6 @@ class ParserParam(ProcessParamBase):
|
||||
self.setups = {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
|
||||
"llm_id": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": [
|
||||
"pdf",
|
||||
@ -85,23 +94,29 @@ class ParserParam(ProcessParamBase):
|
||||
],
|
||||
"output_format": "json",
|
||||
},
|
||||
"markdown": {
|
||||
"suffix": ["md", "markdown"],
|
||||
"text&markdown": {
|
||||
"suffix": ["md", "markdown", "mdx", "txt"],
|
||||
"output_format": "json",
|
||||
},
|
||||
"slides": {
|
||||
"suffix": [
|
||||
"pptx",
|
||||
],
|
||||
"output_format": "json",
|
||||
},
|
||||
"ppt": {},
|
||||
"image": {
|
||||
"parse_method": ["ocr", "vlm"],
|
||||
"parse_method": "ocr",
|
||||
"llm_id": "",
|
||||
"lang": "Chinese",
|
||||
"system_prompt": "",
|
||||
"suffix": ["jpg", "jpeg", "png", "gif"],
|
||||
"output_format": "json",
|
||||
"output_format": "text",
|
||||
},
|
||||
"email": {},
|
||||
"text": {
|
||||
"email": {
|
||||
"suffix": [
|
||||
"txt"
|
||||
"eml", "msg"
|
||||
],
|
||||
"fields": ["from", "to", "cc", "bcc", "date", "subject", "body", "attachments", "metadata"],
|
||||
"output_format": "json",
|
||||
},
|
||||
"audio": {
|
||||
@ -131,13 +146,10 @@ class ParserParam(ProcessParamBase):
|
||||
pdf_config = self.setups.get("pdf", {})
|
||||
if pdf_config:
|
||||
pdf_parse_method = pdf_config.get("parse_method", "")
|
||||
self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"])
|
||||
self.check_empty(pdf_parse_method, "Parse method abnormal.")
|
||||
|
||||
if pdf_parse_method not in ["deepdoc", "plain_text"]:
|
||||
self.check_empty(pdf_config.get("llm_id"), "VLM")
|
||||
|
||||
pdf_language = pdf_config.get("lang", "")
|
||||
self.check_empty(pdf_language, "Language")
|
||||
if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
|
||||
self.check_empty(pdf_config.get("lang", ""), "PDF VLM language")
|
||||
|
||||
pdf_output_format = pdf_config.get("output_format", "")
|
||||
self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
|
||||
@ -147,32 +159,38 @@ class ParserParam(ProcessParamBase):
|
||||
spreadsheet_output_format = spreadsheet_config.get("output_format", "")
|
||||
self.check_valid_value(spreadsheet_output_format, "Spreadsheet output format abnormal.", self.allowed_output_format["spreadsheet"])
|
||||
|
||||
doc_config = self.setups.get("doc", "")
|
||||
doc_config = self.setups.get("word", "")
|
||||
if doc_config:
|
||||
doc_output_format = doc_config.get("output_format", "")
|
||||
self.check_valid_value(doc_output_format, "Word processer document output format abnormal.", self.allowed_output_format["doc"])
|
||||
self.check_valid_value(doc_output_format, "Word processer document output format abnormal.", self.allowed_output_format["word"])
|
||||
|
||||
slides_config = self.setups.get("slides", "")
|
||||
if slides_config:
|
||||
slides_output_format = slides_config.get("output_format", "")
|
||||
self.check_valid_value(slides_output_format, "Slides output format abnormal.", self.allowed_output_format["slides"])
|
||||
|
||||
image_config = self.setups.get("image", "")
|
||||
if image_config:
|
||||
image_parse_method = image_config.get("parse_method", "")
|
||||
self.check_valid_value(image_parse_method.lower(), "Parse method abnormal.", ["ocr", "vlm"])
|
||||
if image_parse_method not in ["ocr"]:
|
||||
self.check_empty(image_config.get("llm_id"), "VLM")
|
||||
self.check_empty(image_config.get("lang", ""), "Image VLM language")
|
||||
|
||||
image_language = image_config.get("lang", "")
|
||||
self.check_empty(image_language, "Language")
|
||||
|
||||
text_config = self.setups.get("text", "")
|
||||
text_config = self.setups.get("text&markdown", "")
|
||||
if text_config:
|
||||
text_output_format = text_config.get("output_format", "")
|
||||
self.check_valid_value(text_output_format, "Text output format abnormal.", self.allowed_output_format["text"])
|
||||
self.check_valid_value(text_output_format, "Text output format abnormal.", self.allowed_output_format["text&markdown"])
|
||||
|
||||
audio_config = self.setups.get("audio", "")
|
||||
if audio_config:
|
||||
self.check_empty(audio_config.get("llm_id"), "VLM")
|
||||
self.check_empty(audio_config.get("llm_id"), "Audio VLM")
|
||||
audio_language = audio_config.get("lang", "")
|
||||
self.check_empty(audio_language, "Language")
|
||||
|
||||
email_config = self.setups.get("email", "")
|
||||
if email_config:
|
||||
email_output_format = email_config.get("output_format", "")
|
||||
self.check_valid_value(email_output_format, "Email output format abnormal.", self.allowed_output_format["email"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {}
|
||||
|
||||
@ -180,21 +198,18 @@ class ParserParam(ProcessParamBase):
|
||||
class Parser(ProcessBase):
|
||||
component_name = "Parser"
|
||||
|
||||
def _pdf(self, from_upstream: ParserFromUpstream):
|
||||
def _pdf(self, name, blob):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.")
|
||||
|
||||
blob = from_upstream.blob
|
||||
conf = self._param.setups["pdf"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
if conf.get("parse_method") == "deepdoc":
|
||||
if conf.get("parse_method").lower() == "deepdoc":
|
||||
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
|
||||
elif conf.get("parse_method") == "plain_text":
|
||||
elif conf.get("parse_method").lower() == "plain_text":
|
||||
lines, _ = PlainParser()(blob)
|
||||
bboxes = [{"text": t} for t, _ in lines]
|
||||
else:
|
||||
assert conf.get("llm_id")
|
||||
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("llm_id"), lang=self._param.setups["pdf"].get("lang"))
|
||||
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("parse_method"), lang=self._param.setups["pdf"].get("lang"))
|
||||
lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
|
||||
bboxes = []
|
||||
for t, poss in lines:
|
||||
@ -214,66 +229,63 @@ class Parser(ProcessBase):
|
||||
mkdn += b.get("text", "") + "\n"
|
||||
self.set_output("markdown", mkdn)
|
||||
|
||||
def _spreadsheet(self, from_upstream: ParserFromUpstream):
|
||||
def _spreadsheet(self, name, blob):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on a Spreadsheet.")
|
||||
|
||||
blob = from_upstream.blob
|
||||
conf = self._param.setups["spreadsheet"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
print("spreadsheet {conf=}", flush=True)
|
||||
spreadsheet_parser = ExcelParser()
|
||||
if conf.get("output_format") == "html":
|
||||
html = spreadsheet_parser.html(blob, 1000000000)
|
||||
self.set_output("html", html)
|
||||
htmls = spreadsheet_parser.html(blob, 1000000000)
|
||||
self.set_output("html", htmls[0])
|
||||
elif conf.get("output_format") == "json":
|
||||
self.set_output("json", [{"text": txt} for txt in spreadsheet_parser(blob) if txt])
|
||||
elif conf.get("output_format") == "markdown":
|
||||
self.set_output("markdown", spreadsheet_parser.markdown(blob))
|
||||
|
||||
def _word(self, from_upstream: ParserFromUpstream):
|
||||
from tika import parser as word_parser
|
||||
|
||||
def _word(self, name, blob):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on a Word Processor Document")
|
||||
|
||||
blob = from_upstream.blob
|
||||
name = from_upstream.name
|
||||
conf = self._param.setups["word"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
print("word {conf=}", flush=True)
|
||||
doc_parsed = word_parser.from_buffer(blob)
|
||||
|
||||
sections = []
|
||||
if doc_parsed.get("content"):
|
||||
sections = doc_parsed["content"].split("\n")
|
||||
sections = [{"text": section} for section in sections if section]
|
||||
else:
|
||||
logging.warning(f"tika.parser got empty content from {name}.")
|
||||
|
||||
docx_parser = Docx()
|
||||
sections, tbls = docx_parser(name, binary=blob)
|
||||
sections = [{"text": section[0], "image": section[1]} for section in sections if section]
|
||||
sections.extend([{"text": tb, "image": None} for ((_,tb), _) in tbls])
|
||||
# json
|
||||
assert conf.get("output_format") == "json", "have to be json for doc"
|
||||
if conf.get("output_format") == "json":
|
||||
self.set_output("json", sections)
|
||||
|
||||
def _markdown(self, from_upstream: ParserFromUpstream):
|
||||
def _slides(self, name, blob):
|
||||
from deepdoc.parser.ppt_parser import RAGFlowPptParser as ppt_parser
|
||||
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on a PowerPoint Document")
|
||||
|
||||
conf = self._param.setups["slides"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
ppt_parser = ppt_parser()
|
||||
txts = ppt_parser(blob, 0, 100000, None)
|
||||
|
||||
sections = [{"text": section} for section in txts if section.strip()]
|
||||
|
||||
# json
|
||||
assert conf.get("output_format") == "json", "have to be json for ppt"
|
||||
if conf.get("output_format") == "json":
|
||||
self.set_output("json", sections)
|
||||
|
||||
def _markdown(self, name, blob):
|
||||
from functools import reduce
|
||||
|
||||
from rag.app.naive import Markdown as naive_markdown_parser
|
||||
from rag.nlp import concat_img
|
||||
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on a markdown.")
|
||||
|
||||
blob = from_upstream.blob
|
||||
name = from_upstream.name
|
||||
conf = self._param.setups["markdown"]
|
||||
conf = self._param.setups["text&markdown"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
markdown_parser = naive_markdown_parser()
|
||||
sections, tables = markdown_parser(name, blob, separate_tables=False)
|
||||
|
||||
# json
|
||||
assert conf.get("output_format") == "json", "have to be json for doc"
|
||||
if conf.get("output_format") == "json":
|
||||
json_results = []
|
||||
|
||||
@ -291,69 +303,51 @@ class Parser(ProcessBase):
|
||||
json_results.append(json_result)
|
||||
|
||||
self.set_output("json", json_results)
|
||||
|
||||
def _text(self, from_upstream: ParserFromUpstream):
|
||||
from deepdoc.parser.utils import get_text
|
||||
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on a text.")
|
||||
|
||||
blob = from_upstream.blob
|
||||
name = from_upstream.name
|
||||
conf = self._param.setups["text"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
# parse binary to text
|
||||
text_content = get_text(name, binary=blob)
|
||||
|
||||
if conf.get("output_format") == "json":
|
||||
result = [{"text": text_content}]
|
||||
self.set_output("json", result)
|
||||
else:
|
||||
result = text_content
|
||||
self.set_output("text", result)
|
||||
self.set_output("text", "\n".join([section_text for section_text, _ in sections]))
|
||||
|
||||
def _image(self, from_upstream: ParserFromUpstream):
|
||||
|
||||
def _image(self, name, blob):
|
||||
from deepdoc.vision import OCR
|
||||
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on an image.")
|
||||
|
||||
blob = from_upstream.blob
|
||||
conf = self._param.setups["image"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
img = Image.open(io.BytesIO(blob)).convert("RGB")
|
||||
lang = conf["lang"]
|
||||
|
||||
if conf["parse_method"] == "ocr":
|
||||
# use ocr, recognize chars only
|
||||
ocr = OCR()
|
||||
bxs = ocr(np.array(img)) # return boxes and recognize result
|
||||
txt = "\n".join([t[0] for _, t in bxs if t[0]])
|
||||
|
||||
else:
|
||||
lang = conf["lang"]
|
||||
# use VLM to describe the picture
|
||||
cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"],lang=lang)
|
||||
cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["parse_method"], lang=lang)
|
||||
img_binary = io.BytesIO()
|
||||
img.save(img_binary, format="JPEG")
|
||||
img_binary.seek(0)
|
||||
txt = cv_model.describe(img_binary.read())
|
||||
|
||||
system_prompt = conf.get("system_prompt")
|
||||
if system_prompt:
|
||||
txt = cv_model.describe_with_prompt(img_binary.read(), system_prompt)
|
||||
else:
|
||||
txt = cv_model.describe(img_binary.read())
|
||||
|
||||
self.set_output("text", txt)
|
||||
|
||||
def _audio(self, from_upstream: ParserFromUpstream):
|
||||
def _audio(self, name, blob):
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on an audio.")
|
||||
|
||||
blob = from_upstream.blob
|
||||
name = from_upstream.name
|
||||
conf = self._param.setups["audio"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
lang = conf["lang"]
|
||||
_, ext = os.path.splitext(name)
|
||||
tmp_path = ""
|
||||
with tempfile.NamedTemporaryFile(suffix=ext) as tmpf:
|
||||
tmpf.write(blob)
|
||||
tmpf.flush()
|
||||
@ -364,15 +358,131 @@ class Parser(ProcessBase):
|
||||
|
||||
self.set_output("text", txt)
|
||||
|
||||
def _email(self, name, blob):
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to work on an email.")
|
||||
|
||||
email_content = {}
|
||||
conf = self._param.setups["email"]
|
||||
target_fields = conf["fields"]
|
||||
|
||||
_, ext = os.path.splitext(name)
|
||||
if ext == ".eml":
|
||||
# handle eml file
|
||||
from email import policy
|
||||
from email.parser import BytesParser
|
||||
|
||||
msg = BytesParser(policy=policy.default).parse(io.BytesIO(blob))
|
||||
email_content['metadata'] = {}
|
||||
# handle header info
|
||||
for header, value in msg.items():
|
||||
# get fields like from, to, cc, bcc, date, subject
|
||||
if header.lower() in target_fields:
|
||||
email_content[header.lower()] = value
|
||||
# get metadata
|
||||
elif header.lower() not in ["from", "to", "cc", "bcc", "date", "subject"]:
|
||||
email_content["metadata"][header.lower()] = value
|
||||
# get body
|
||||
if "body" in target_fields:
|
||||
body_text, body_html = [], []
|
||||
def _add_content(m, content_type):
|
||||
if content_type == "text/plain":
|
||||
body_text.append(
|
||||
m.get_payload(decode=True).decode(m.get_content_charset())
|
||||
)
|
||||
elif content_type == "text/html":
|
||||
body_html.append(
|
||||
m.get_payload(decode=True).decode(m.get_content_charset())
|
||||
)
|
||||
elif "multipart" in content_type:
|
||||
if m.is_multipart():
|
||||
for part in m.iter_parts():
|
||||
_add_content(part, part.get_content_type())
|
||||
|
||||
_add_content(msg, msg.get_content_type())
|
||||
|
||||
email_content["text"] = body_text
|
||||
email_content["text_html"] = body_html
|
||||
# get attachment
|
||||
if "attachments" in target_fields:
|
||||
attachments = []
|
||||
for part in msg.iter_attachments():
|
||||
content_disposition = part.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
dispositions = content_disposition.strip().split(";")
|
||||
if dispositions[0].lower() == "attachment":
|
||||
filename = part.get_filename()
|
||||
payload = part.get_payload(decode=True)
|
||||
attachments.append({
|
||||
"filename": filename,
|
||||
"payload": payload,
|
||||
})
|
||||
email_content["attachments"] = attachments
|
||||
else:
|
||||
# handle msg file
|
||||
import extract_msg
|
||||
print("handle a msg file.")
|
||||
msg = extract_msg.Message(blob)
|
||||
# handle header info
|
||||
basic_content = {
|
||||
"from": msg.sender,
|
||||
"to": msg.to,
|
||||
"cc": msg.cc,
|
||||
"bcc": msg.bcc,
|
||||
"date": msg.date,
|
||||
"subject": msg.subject,
|
||||
}
|
||||
email_content.update({k: v for k, v in basic_content.items() if k in target_fields})
|
||||
# get metadata
|
||||
email_content['metadata'] = {
|
||||
'message_id': msg.messageId,
|
||||
'in_reply_to': msg.inReplyTo,
|
||||
}
|
||||
# get body
|
||||
if "body" in target_fields:
|
||||
email_content["text"] = msg.body # usually empty. try text_html instead
|
||||
email_content["text_html"] = msg.htmlBody
|
||||
# get attachments
|
||||
if "attachments" in target_fields:
|
||||
attachments = []
|
||||
for t in msg.attachments:
|
||||
attachments.append({
|
||||
"filename": t.name,
|
||||
"payload": t.data # binary
|
||||
})
|
||||
email_content["attachments"] = attachments
|
||||
|
||||
if conf["output_format"] == "json":
|
||||
self.set_output("json", [email_content])
|
||||
else:
|
||||
content_txt = ''
|
||||
for k, v in email_content.items():
|
||||
if isinstance(v, str):
|
||||
# basic info
|
||||
content_txt += f'{k}:{v}' + "\n"
|
||||
elif isinstance(v, dict):
|
||||
# metadata
|
||||
content_txt += f'{k}:{json.dumps(v)}' + "\n"
|
||||
elif isinstance(v, list):
|
||||
# attachments or others
|
||||
for fb in v:
|
||||
if isinstance(fb, dict):
|
||||
# attachments
|
||||
content_txt += f'{fb["filename"]}:{fb["payload"]}' + "\n"
|
||||
else:
|
||||
# str, usually plain text
|
||||
content_txt += fb
|
||||
self.set_output("text", content_txt)
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
function_map = {
|
||||
"pdf": self._pdf,
|
||||
"markdown": self._markdown,
|
||||
"text&markdown": self._markdown,
|
||||
"spreadsheet": self._spreadsheet,
|
||||
"slides": self._slides,
|
||||
"word": self._word,
|
||||
"text": self._text,
|
||||
"image": self._image,
|
||||
"audio": self._audio,
|
||||
"email": self._email,
|
||||
}
|
||||
try:
|
||||
from_upstream = ParserFromUpstream.model_validate(kwargs)
|
||||
@ -380,8 +490,25 @@ class Parser(ProcessBase):
|
||||
self.set_output("_ERROR", f"Input error: {str(e)}")
|
||||
return
|
||||
|
||||
name = from_upstream.name
|
||||
if self._canvas._doc_id:
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
|
||||
blob = STORAGE_IMPL.get(b, n)
|
||||
else:
|
||||
blob = FileService.get_blob(from_upstream.file["created_by"], from_upstream.file["id"])
|
||||
|
||||
done = False
|
||||
for p_type, conf in self._param.setups.items():
|
||||
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
|
||||
continue
|
||||
await trio.to_thread.run_sync(function_map[p_type], from_upstream)
|
||||
await trio.to_thread.run_sync(function_map[p_type], name, blob)
|
||||
done = True
|
||||
break
|
||||
|
||||
if not done:
|
||||
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
|
||||
|
||||
outs = self.output()
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in outs.get("json", []):
|
||||
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
||||
|
||||
@ -20,6 +20,5 @@ class ParserFromUpstream(BaseModel):
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
blob: bytes
|
||||
|
||||
file: dict | None = Field(default=None)
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
@ -17,41 +17,92 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
|
||||
from agent.canvas import Graph
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class Pipeline(Graph):
|
||||
def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
|
||||
def __init__(self, dsl: str|dict, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
|
||||
if isinstance(dsl, dict):
|
||||
dsl = json.dumps(dsl, ensure_ascii=False)
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
doc_id = None
|
||||
self._doc_id = doc_id
|
||||
self._flow_id = flow_id
|
||||
self._kb_id = None
|
||||
if doc_id:
|
||||
if self._doc_id:
|
||||
self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
|
||||
assert self._kb_id, f"Can't find KB of this document: {doc_id}"
|
||||
if not self._kb_id:
|
||||
self._doc_id = None
|
||||
|
||||
def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
|
||||
from rag.svr.task_executor import TaskCanceledException
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
timestamp = timer()
|
||||
if has_canceled(self.task_id):
|
||||
progress = -1
|
||||
message += "[CANCEL]"
|
||||
try:
|
||||
bin = REDIS_CONN.get(log_key)
|
||||
obj = json.loads(bin.encode("utf-8"))
|
||||
if obj:
|
||||
if obj[-1]["component_name"] == component_name:
|
||||
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
|
||||
if obj[-1]["component_id"] == component_name:
|
||||
obj[-1]["trace"].append(
|
||||
{
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
"datetime": datetime.datetime.now().strftime("%H:%M:%S"),
|
||||
"timestamp": timestamp,
|
||||
"elapsed_time": timestamp - obj[-1]["trace"][-1]["timestamp"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]})
|
||||
obj.append(
|
||||
{
|
||||
"component_id": component_name,
|
||||
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}],
|
||||
}
|
||||
)
|
||||
else:
|
||||
obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}]
|
||||
REDIS_CONN.set_obj(log_key, obj, 60 * 10)
|
||||
obj = [
|
||||
{
|
||||
"component_id": component_name,
|
||||
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}],
|
||||
}
|
||||
]
|
||||
if component_name != "END" and self._doc_id and self.task_id:
|
||||
percentage = 1.0 / len(self.components.items())
|
||||
finished = 0.0
|
||||
for o in obj:
|
||||
for t in o["trace"]:
|
||||
if t["progress"] < 0:
|
||||
finished = -1
|
||||
break
|
||||
if finished < 0:
|
||||
break
|
||||
finished += o["trace"][-1]["progress"] * percentage
|
||||
|
||||
msg = ""
|
||||
if len(obj[-1]["trace"]) == 1:
|
||||
msg += f"\n-------------------------------------\n[{self.get_component_name(o['component_id'])}]:\n"
|
||||
t = obj[-1]["trace"][-1]
|
||||
msg += "%s: %s\n" % (t["datetime"], t["message"])
|
||||
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
|
||||
elif component_name == "END" and not self._doc_id:
|
||||
obj[-1]["trace"][-1]["dsl"] = json.loads(str(self))
|
||||
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
if has_canceled(self.task_id):
|
||||
raise TaskCanceledException(message)
|
||||
|
||||
def fetch_logs(self):
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
@ -62,34 +113,32 @@ class Pipeline(Graph):
|
||||
logging.exception(e)
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
|
||||
async def run(self, **kwargs):
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
REDIS_CONN.set_obj(log_key, [], 60 * 10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.error = ""
|
||||
if not self.path:
|
||||
self.path.append("File")
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(
|
||||
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
)
|
||||
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
if idx == 0:
|
||||
cpn_obj = self.get_component_obj(self.path[0])
|
||||
await cpn_obj.invoke(**kwargs)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
else:
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
|
||||
if self._doc_id:
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": random.randint(0, 5) / 100.0,
|
||||
"progress_msg": "Start the pipeline...",
|
||||
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
|
||||
|
||||
idx = len(self.path) - 1
|
||||
cpn_obj = self.get_component_obj(self.path[idx])
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
while idx < len(self.path) and not self.error:
|
||||
last_cpn = self.get_component_obj(self.path[idx - 1])
|
||||
@ -98,15 +147,28 @@ class Pipeline(Graph):
|
||||
async def invoke():
|
||||
nonlocal last_cpn, cpn_obj
|
||||
await cpn_obj.invoke(**last_cpn.output())
|
||||
#if inspect.iscoroutinefunction(cpn_obj.invoke):
|
||||
# await cpn_obj.invoke(**last_cpn.output())
|
||||
#else:
|
||||
# cpn_obj.invoke(**last_cpn.output())
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
self.callback(cpn_obj.component_name, -1, self.error)
|
||||
self.callback(cpn_obj._id, -1, self.error)
|
||||
break
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(self._doc_id, {"progress": 1 if not self.error else -1, "progress_msg": "Pipeline finished...\n" + self.error, "process_duration": time.perf_counter() - st})
|
||||
self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
|
||||
|
||||
if not self.error:
|
||||
return self.get_component_obj(self.path[-1]).output()
|
||||
|
||||
TaskService.update_progress(self.task_id, {
|
||||
"progress": -1,
|
||||
"progress_msg": f"[ERROR]: {self.error}"})
|
||||
|
||||
return {}
|
||||
|
||||
15
rag/flow/splitter/__init__.py
Normal file
15
rag/flow/splitter/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
38
rag/flow/splitter/schema.py
Normal file
38
rag/flow/splitter/schema.py
Normal file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class SplitterFromUpstream(BaseModel):
|
||||
created_time: float | None = Field(default=None, alias="_created_time")
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
html_result: str | None = Field(default=None, alias="html")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# def to_dict(self, *, exclude_none: bool = True) -> dict:
|
||||
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
|
||||
111
rag/flow/splitter/splitter.py
Normal file
111
rag/flow/splitter/splitter.py
Normal file
@ -0,0 +1,111 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
||||
from api.utils import get_uuid
|
||||
from api.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.splitter.schema import SplitterFromUpstream
|
||||
from rag.nlp import naive_merge, naive_merge_with_images
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class SplitterParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.chunk_token_size = 512
|
||||
self.delimiters = ["\n"]
|
||||
self.overlapped_percent = 0
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.delimiters, "Delimiters.")
|
||||
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
|
||||
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {}
|
||||
|
||||
|
||||
class Splitter(ProcessBase):
|
||||
component_name = "Splitter"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
try:
|
||||
from_upstream = SplitterFromUpstream.model_validate(kwargs)
|
||||
except Exception as e:
|
||||
self.set_output("_ERROR", f"Input error: {str(e)}")
|
||||
return
|
||||
|
||||
deli = ""
|
||||
for d in self._param.delimiters:
|
||||
if len(d) > 1:
|
||||
deli += f"`{d}`"
|
||||
else:
|
||||
deli += d
|
||||
|
||||
self.set_output("output_format", "chunks")
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")
|
||||
if from_upstream.output_format in ["markdown", "text", "html"]:
|
||||
if from_upstream.output_format == "markdown":
|
||||
payload = from_upstream.markdown_result
|
||||
elif from_upstream.output_format == "text":
|
||||
payload = from_upstream.text_result
|
||||
else: # == "html"
|
||||
payload = from_upstream.html_result
|
||||
|
||||
if not payload:
|
||||
payload = ""
|
||||
|
||||
cks = naive_merge(
|
||||
payload,
|
||||
self._param.chunk_token_size,
|
||||
deli,
|
||||
self._param.overlapped_percent,
|
||||
)
|
||||
self.set_output("chunks", [{"text": c.strip()} for c in cks if c.strip()])
|
||||
|
||||
self.callback(1, "Done.")
|
||||
return
|
||||
|
||||
# json
|
||||
sections, section_images = [], []
|
||||
for o in from_upstream.json_result or []:
|
||||
sections.append((o.get("text", ""), o.get("position_tag", "")))
|
||||
section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get)))
|
||||
|
||||
chunks, images = naive_merge_with_images(
|
||||
sections,
|
||||
section_images,
|
||||
self._param.chunk_token_size,
|
||||
deli,
|
||||
self._param.overlapped_percent,
|
||||
)
|
||||
cks = [
|
||||
{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
|
||||
}
|
||||
for c, img in zip(chunks, images) if c.strip()
|
||||
]
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in cks:
|
||||
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
|
||||
self.set_output("chunks", cks)
|
||||
self.callback(1, "Done.")
|
||||
@ -30,7 +30,7 @@ def print_logs(pipeline: Pipeline):
|
||||
while True:
|
||||
time.sleep(5)
|
||||
logs = pipeline.fetch_logs()
|
||||
logs_str = json.dumps(logs)
|
||||
logs_str = json.dumps(logs, ensure_ascii=False)
|
||||
if logs_str != last_logs:
|
||||
print(logs_str)
|
||||
last_logs = logs_str
|
||||
|
||||
@ -38,6 +38,13 @@
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"slides": {
|
||||
"parse_method": "presentation",
|
||||
"suffix": [
|
||||
"pptx"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"markdown": {
|
||||
"suffix": [
|
||||
"md",
|
||||
@ -82,19 +89,36 @@
|
||||
"lang": "Chinese",
|
||||
"llm_id": "SenseVoiceSmall",
|
||||
"output_format": "json"
|
||||
},
|
||||
"email": {
|
||||
"suffix": [
|
||||
"msg"
|
||||
],
|
||||
"fields": [
|
||||
"from",
|
||||
"to",
|
||||
"cc",
|
||||
"bcc",
|
||||
"date",
|
||||
"subject",
|
||||
"body",
|
||||
"attachments"
|
||||
],
|
||||
"output_format": "json"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"downstream": ["Chunker:0"],
|
||||
"downstream": ["Splitter:0"],
|
||||
"upstream": ["Begin"]
|
||||
},
|
||||
"Chunker:0": {
|
||||
"Splitter:0": {
|
||||
"obj": {
|
||||
"component_name": "Chunker",
|
||||
"component_name": "Splitter",
|
||||
"params": {
|
||||
"method": "general",
|
||||
"auto_keywords": 5
|
||||
"chunk_token_size": 512,
|
||||
"delimiters": ["\n"],
|
||||
"overlapped_percent": 0
|
||||
}
|
||||
},
|
||||
"downstream": ["Tokenizer:0"],
|
||||
|
||||
84
rag/flow/tests/dsl_examples/hierarchical_merger.json
Normal file
84
rag/flow/tests/dsl_examples/hierarchical_merger.json
Normal file
@ -0,0 +1,84 @@
|
||||
{
|
||||
"components": {
|
||||
"File": {
|
||||
"obj":{
|
||||
"component_name": "File",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": ["Parser:0"],
|
||||
"upstream": []
|
||||
},
|
||||
"Parser:0": {
|
||||
"obj": {
|
||||
"component_name": "Parser",
|
||||
"params": {
|
||||
"setups": {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc",
|
||||
"vlm_name": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": [
|
||||
"pdf"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"spreadsheet": {
|
||||
"suffix": [
|
||||
"xls",
|
||||
"xlsx",
|
||||
"csv"
|
||||
],
|
||||
"output_format": "html"
|
||||
},
|
||||
"word": {
|
||||
"suffix": [
|
||||
"doc",
|
||||
"docx"
|
||||
],
|
||||
"output_format": "json"
|
||||
},
|
||||
"markdown": {
|
||||
"suffix": [
|
||||
"md",
|
||||
"markdown"
|
||||
],
|
||||
"output_format": "text"
|
||||
},
|
||||
"text": {
|
||||
"suffix": ["txt"],
|
||||
"output_format": "json"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"downstream": ["Splitter:0"],
|
||||
"upstream": ["File"]
|
||||
},
|
||||
"Splitter:0": {
|
||||
"obj": {
|
||||
"component_name": "Splitter",
|
||||
"params": {
|
||||
"chunk_token_size": 512,
|
||||
"delimiters": ["\r\n"],
|
||||
"overlapped_percent": 0
|
||||
}
|
||||
},
|
||||
"downstream": ["HierarchicalMerger:0"],
|
||||
"upstream": ["Parser:0"]
|
||||
},
|
||||
"HierarchicalMerger:0": {
|
||||
"obj": {
|
||||
"component_name": "HierarchicalMerger",
|
||||
"params": {
|
||||
"levels": [["^#[^#]"], ["^##[^#]"], ["^###[^#]"], ["^####[^#]"]],
|
||||
"hierarchy": 2
|
||||
}
|
||||
},
|
||||
"downstream": [],
|
||||
"upstream": ["Splitter:0"]
|
||||
}
|
||||
},
|
||||
"path": []
|
||||
}
|
||||
|
||||
@ -22,16 +22,16 @@ class TokenizerFromUpstream(BaseModel):
|
||||
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
|
||||
|
||||
name: str = ""
|
||||
blob: bytes
|
||||
file: dict | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
|
||||
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
html_result: list[str] | None = Field(default=None, alias="html")
|
||||
html_result: str | None = Field(default=None, alias="html")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
@ -40,12 +40,14 @@ class TokenizerFromUpstream(BaseModel):
|
||||
if self.chunks:
|
||||
return self
|
||||
|
||||
if self.output_format in {"markdown", "text"}:
|
||||
if self.output_format in {"markdown", "text", "html"}:
|
||||
if self.output_format == "markdown" and not self.markdown_result:
|
||||
raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').")
|
||||
if self.output_format == "text" and not self.text_result:
|
||||
raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').")
|
||||
if self.output_format == "html" and not self.html_result:
|
||||
raise ValueError("output_format=text requires a html payload (field: 'html' or 'html_result').")
|
||||
else:
|
||||
if not self.json_result:
|
||||
if not self.json_result and not self.chunks:
|
||||
raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').")
|
||||
return self
|
||||
|
||||
@ -37,6 +37,7 @@ class TokenizerParam(ProcessParamBase):
|
||||
super().__init__()
|
||||
self.search_method = ["full_text", "embedding"]
|
||||
self.filename_embd_weight = 0.1
|
||||
self.fields = ["text"]
|
||||
|
||||
def check(self):
|
||||
for v in self.search_method:
|
||||
@ -61,10 +62,14 @@ class Tokenizer(ProcessBase):
|
||||
embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
texts = []
|
||||
for c in chunks:
|
||||
if c.get("questions"):
|
||||
texts.append("\n".join(c["questions"]))
|
||||
else:
|
||||
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c["text"]))
|
||||
txt = ""
|
||||
for f in self._param.fields:
|
||||
f = c.get(f)
|
||||
if isinstance(f, str):
|
||||
txt += f
|
||||
elif isinstance(f, list):
|
||||
txt += "\n".join(f)
|
||||
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", txt))
|
||||
vts, c = embedding_model.encode([name])
|
||||
token_count += c
|
||||
tts = np.concatenate([vts[0] for _ in range(len(texts))], axis=0)
|
||||
@ -103,26 +108,36 @@ class Tokenizer(ProcessBase):
|
||||
self.set_output("_ERROR", f"Input error: {str(e)}")
|
||||
return
|
||||
|
||||
self.set_output("output_format", "chunks")
|
||||
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
|
||||
if "full_text" in self._param.search_method:
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
|
||||
if from_upstream.chunks:
|
||||
chunks = from_upstream.chunks
|
||||
for i, ck in enumerate(chunks):
|
||||
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
|
||||
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
|
||||
if ck.get("questions"):
|
||||
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
|
||||
ck["question_kwd"] = ck["questions"].split("\n")
|
||||
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
|
||||
if ck.get("keywords"):
|
||||
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
ck["important_kwd"] = ck["keywords"].split(",")
|
||||
ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
|
||||
if ck.get("summary"):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
else:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i * 1.0 / len(chunks) / parts)
|
||||
|
||||
elif from_upstream.output_format in ["markdown", "text", "html"]:
|
||||
if from_upstream.output_format == "markdown":
|
||||
payload = from_upstream.markdown_result
|
||||
elif from_upstream.output_format == "text":
|
||||
payload = from_upstream.text_result
|
||||
else: # == "html"
|
||||
else:
|
||||
payload = from_upstream.html_result
|
||||
|
||||
if not payload:
|
||||
@ -130,12 +145,16 @@ class Tokenizer(ProcessBase):
|
||||
|
||||
ck = {"text": payload}
|
||||
if "full_text" in self._param.search_method:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
|
||||
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
|
||||
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(payload)
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
chunks = [ck]
|
||||
else:
|
||||
chunks = from_upstream.json_result
|
||||
for i, ck in enumerate(chunks):
|
||||
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
|
||||
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
|
||||
@ -143,10 +143,10 @@ class Base(ABC):
|
||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
||||
if self.model_name.lower().find("qwen3") >= 0:
|
||||
kwargs["extra_body"] = {"enable_thinking": False}
|
||||
|
||||
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||
|
||||
if (not response.choices or not response.choices[0].message or not response.choices[0].message.content):
|
||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
@ -156,12 +156,12 @@ class Base(ABC):
|
||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||
reasoning_start = False
|
||||
|
||||
|
||||
if kwargs.get("stop") or "stop" in gen_conf:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
||||
else:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
|
||||
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
@ -457,7 +457,7 @@ class Base(ABC):
|
||||
yield total_tokens
|
||||
|
||||
def total_token_count(self, resp):
|
||||
return total_token_count_from_response(resp)
|
||||
return total_token_count_from_response(resp)
|
||||
|
||||
def _calculate_dynamic_ctx(self, history):
|
||||
"""Calculate dynamic context window size"""
|
||||
@ -643,7 +643,7 @@ class ZhipuChat(Base):
|
||||
del gen_conf["max_tokens"]
|
||||
gen_conf = self._clean_conf_plealty(gen_conf)
|
||||
return gen_conf
|
||||
|
||||
|
||||
def _clean_conf_plealty(self, gen_conf):
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
@ -1305,10 +1305,6 @@ class LiteLLMBase(ABC):
|
||||
"302.AI",
|
||||
]
|
||||
|
||||
import litellm
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
||||
self.provider = kwargs.get("provider", "")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user