Fix: Merge main branch (#10377)

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: jinhai <haijin.chn@gmail.com>
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Lynn <lynn_inf@hotmail.com>
Co-authored-by: chanx <1243304602@qq.com>
Co-authored-by: balibabu <cike8899@users.noreply.github.com>
Co-authored-by: 纷繁下的无奈 <zhileihuang@126.com>
Co-authored-by: huangzl <huangzl@shinemo.com>
Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
Co-authored-by: Wilmer <33392318@qq.com>
Co-authored-by: Adrian Weidig <adrianweidig@gmx.net>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yongteng Lei <yongtengrey@outlook.com>
Co-authored-by: Liu An <asiro@qq.com>
Co-authored-by: buua436 <66937541+buua436@users.noreply.github.com>
Co-authored-by: BadwomanCraZY <511528396@qq.com>
Co-authored-by: cucusenok <31804608+cucusenok@users.noreply.github.com>
Co-authored-by: Russell Valentine <russ@coldstonelabs.org>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Billy Bao <newyorkupperbay@gmail.com>
Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: TensorNull <129579691+TensorNull@users.noreply.github.com>
Co-authored-by: TensorNull <tensor.null@gmail.com>
Co-authored-by: Ajay <160579663+aybanda@users.noreply.github.com>
Co-authored-by: AB <aj@Ajays-MacBook-Air.local>
Co-authored-by: 天海蒼灆 <huangaoqin@tecpie.com>
Co-authored-by: He Wang <wanghechn@qq.com>
Co-authored-by: Atsushi Hatakeyama <atu729@icloud.com>
Co-authored-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Mohamed Mathari <155896313+melmathari@users.noreply.github.com>
Co-authored-by: Mohamed Mathari <nocodeventure@Mac-mini-van-Mohamed.fritz.box>
Co-authored-by: Stephen Hu <stephenhu@seismic.com>
Co-authored-by: Shaun Zhang <zhangwfjh@users.noreply.github.com>
Co-authored-by: zhimeng123 <60221886+zhimeng123@users.noreply.github.com>
Co-authored-by: mxc <mxc@example.com>
Co-authored-by: Dominik Novotný <50611433+SgtMarmite@users.noreply.github.com>
Co-authored-by: EVGENY M <168018528+rjohny55@users.noreply.github.com>
Co-authored-by: mcoder6425 <mcoder64@gmail.com>
Co-authored-by: TeslaZY <TeslaZY@outlook.com>
Co-authored-by: lemsn <lemsn@msn.com>
Co-authored-by: lemsn <lemsn@126.com>
Co-authored-by: Adrian Gora <47756404+adagora@users.noreply.github.com>
Co-authored-by: Womsxd <45663319+Womsxd@users.noreply.github.com>
Co-authored-by: FatMii <39074672+FatMii@users.noreply.github.com>
This commit is contained in:
Kevin Hu
2025-09-30 13:13:15 +08:00
committed by GitHub
parent 4d6ff672eb
commit 20b577a72c
201 changed files with 7929 additions and 1110 deletions

101
admin/README.md Normal file
View File

@ -0,0 +1,101 @@
# RAGFlow Admin Service & CLI
### Introduction
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
Built with scalability and reliability in mind, the Admin Service ensures smooth system operation and simplifies maintenance workflows.
It consists of a server-side Service and a command-line client (CLI), both implemented in Python. User commands are parsed using the Lark parsing toolkit.
- **Admin Service**: A backend service that interfaces with the RAGFlow system to execute administrative operations and monitor its status.
- **Admin CLI**: A command-line interface that allows users to connect to the Admin Service and issue commands for system management.
### Starting the Admin Service
1. Before start Admin Service, please make sure RAGFlow system is already started.
2. Run the service script:
```bash
python admin/admin_server.py
```
The service will start and listen for incoming connections from the CLI on the configured port.
### Using the Admin CLI
1. Ensure the Admin Service is running.
2. Launch the CLI client:
```bash
python admin/admin_client.py -h 0.0.0.0 -p 9381
## Supported Commands
Commands are case-insensitive and must be terminated with a semicolon (`;`).
### Service Management Commands
- `LIST SERVICES;`
- Lists all available services within the RAGFlow system.
- `SHOW SERVICE <id>;`
- Shows detailed status information for the service identified by `<id>`.
- `STARTUP SERVICE <id>;`
- Attempts to start the service identified by `<id>`.
- `SHUTDOWN SERVICE <id>;`
- Attempts to gracefully shut down the service identified by `<id>`.
- `RESTART SERVICE <id>;`
- Attempts to restart the service identified by `<id>`.
### User Management Commands
- `LIST USERS;`
- Lists all users known to the system.
- `SHOW USER '<username>';`
- Shows details and permissions for the specified user. The username must be enclosed in single or double quotes.
- `DROP USER '<username>';`
- Removes the specified user from the system. Use with caution.
- `ALTER USER PASSWORD '<username>' '<new_password>';`
- Changes the password for the specified user.
### Data and Agent Commands
- `LIST DATASETS OF '<username>';`
- Lists the datasets associated with the specified user.
- `LIST AGENTS OF '<username>';`
- Lists the agents associated with the specified user.
### Meta-Commands
Meta-commands are prefixed with a backslash (`\`).
- `\?` or `\help`
- Shows help information for the available commands.
- `\q` or `\quit`
- Exits the CLI application.
## Examples
```commandline
admin> list users;
+-------------------------------+------------------------+-----------+-------------+
| create_date | email | is_active | nickname |
+-------------------------------+------------------------+-----------+-------------+
| Fri, 22 Nov 2024 16:03:41 GMT | jeffery@infiniflow.org | 1 | Jeffery |
| Fri, 22 Nov 2024 16:10:55 GMT | aya@infiniflow.org | 1 | Waterdancer |
+-------------------------------+------------------------+-----------+-------------+
admin> list services;
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
| extra | host | id | name | port | service_type |
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
| {} | 0.0.0.0 | 0 | ragflow_0 | 9380 | ragflow_server |
| {'meta_type': 'mysql', 'password': 'infini_rag_flow', 'username': 'root'} | localhost | 1 | mysql | 5455 | meta_data |
| {'password': 'infini_rag_flow', 'store_type': 'minio', 'user': 'rag_flow'} | localhost | 2 | minio | 9000 | file_store |
| {'password': 'infini_rag_flow', 'retrieval_type': 'elasticsearch', 'username': 'elastic'} | localhost | 3 | elasticsearch | 1200 | retrieval |
| {'db_name': 'default_db', 'retrieval_type': 'infinity'} | localhost | 4 | infinity | 23817 | retrieval |
| {'database': 1, 'mq_type': 'redis', 'password': 'infini_rag_flow'} | localhost | 5 | redis | 6379 | message_queue |
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
```

574
admin/admin_client.py Normal file
View File

@ -0,0 +1,574 @@
import argparse
import base64
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from typing import Dict, List, Any
from lark import Lark, Transformer, Tree
import requests
from requests.auth import HTTPBasicAuth
from api.common.base64 import encode_to_base64
GRAMMAR = r"""
start: command
command: sql_command | meta_command
sql_command: list_services
| show_service
| startup_service
| shutdown_service
| restart_service
| list_users
| show_user
| drop_user
| alter_user
| create_user
| activate_user
| list_datasets
| list_agents
// meta command definition
meta_command: "\\" meta_command_name [meta_args]
meta_command_name: /[a-zA-Z?]+/
meta_args: (meta_arg)+
meta_arg: /[^\\s"']+/ | quoted_string
// command definition
LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
RESTART: "RESTART"i
USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
PASSWORD: "PASSWORD"i
DATASETS: "DATASETS"i
OF: "OF"i
AGENTS: "AGENTS"i
list_services: LIST SERVICES ";"
show_service: SHOW SERVICE NUMBER ";"
startup_service: STARTUP SERVICE NUMBER ";"
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
restart_service: RESTART SERVICE NUMBER ";"
list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"
list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"
identifier: WORD
quoted_string: QUOTED_STRING
status: WORD
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
NUMBER: /[0-9]+/
%import common.WS
%ignore WS
"""
class AdminTransformer(Transformer):
def start(self, items):
return items[0]
def command(self, items):
return items[0]
def list_services(self, items):
result = {'type': 'list_services'}
return result
def show_service(self, items):
service_id = int(items[2])
return {"type": "show_service", "number": service_id}
def startup_service(self, items):
service_id = int(items[2])
return {"type": "startup_service", "number": service_id}
def shutdown_service(self, items):
service_id = int(items[2])
return {"type": "shutdown_service", "number": service_id}
def restart_service(self, items):
service_id = int(items[2])
return {"type": "restart_service", "number": service_id}
def list_users(self, items):
return {"type": "list_users"}
def show_user(self, items):
user_name = items[2]
return {"type": "show_user", "username": user_name}
def drop_user(self, items):
user_name = items[2]
return {"type": "drop_user", "username": user_name}
def alter_user(self, items):
user_name = items[3]
new_password = items[4]
return {"type": "alter_user", "username": user_name, "password": new_password}
def create_user(self, items):
user_name = items[2]
password = items[3]
return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
def activate_user(self, items):
user_name = items[3]
activate_status = items[4]
return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
def list_datasets(self, items):
user_name = items[3]
return {"type": "list_datasets", "username": user_name}
def list_agents(self, items):
user_name = items[3]
return {"type": "list_agents", "username": user_name}
def meta_command(self, items):
command_name = str(items[0]).lower()
args = items[1:] if len(items) > 1 else []
# handle quoted parameter
parsed_args = []
for arg in args:
if hasattr(arg, 'value'):
parsed_args.append(arg.value)
else:
parsed_args.append(str(arg))
return {'type': 'meta', 'command': command_name, 'args': parsed_args}
def meta_command_name(self, items):
return items[0]
def meta_args(self, items):
return items
def encrypt(input_string):
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
pub_key = RSA.importKey(pub)
cipher = Cipher_pkcs1_v1_5.new(pub_key)
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
return base64.b64encode(cipher_text).decode("utf-8")
class AdminCommandParser:
def __init__(self):
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
self.command_history = []
def parse_command(self, command_str: str) -> Dict[str, Any]:
if not command_str.strip():
return {'type': 'empty'}
self.command_history.append(command_str)
try:
result = self.parser.parse(command_str)
return result
except Exception as e:
return {'type': 'error', 'message': f'Parse error: {str(e)}'}
class AdminCLI:
def __init__(self):
self.parser = AdminCommandParser()
self.is_interactive = False
self.admin_account = "admin@ragflow.io"
self.admin_password: str = "admin"
self.host: str = ""
self.port: int = 0
def verify_admin(self, args):
conn_info = self._parse_connection_args(args)
if 'error' in conn_info:
print(f"Error: {conn_info['error']}")
return
self.host = conn_info['host']
self.port = conn_info['port']
print(f"Attempt to access ip: {self.host}, port: {self.port}")
url = f'http://{self.host}:{self.port}/api/v1/admin/auth'
try_count = 0
while True:
try_count += 1
if try_count > 3:
return False
admin_passwd = input(f"password for {self.admin_account}: ").strip()
try:
self.admin_password = encode_to_base64(admin_passwd)
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
if response.status_code == 200:
res_json = response.json()
error_code = res_json.get('code', -1)
if error_code == 0:
print("Authentication successful.")
return True
else:
error_message = res_json.get('message', 'Unknown error')
print(f"Authentication failed: {error_message}, try again")
continue
else:
print(f"Bad responsestatus: {response.status_code}, try again")
except Exception:
print(f"Can't access {self.host}, port: {self.port}")
def _print_table_simple(self, data):
if not data:
print("No data to print")
return
if isinstance(data, dict):
# handle single row data
data = [data]
columns = list(data[0].keys())
col_widths = {}
for col in columns:
max_width = len(str(col))
for item in data:
value_len = len(str(item.get(col, '')))
if value_len > max_width:
max_width = value_len
col_widths[col] = max(2, max_width)
# Generate delimiter
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
# Print header
print(separator)
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
print(header)
print(separator)
# Print data
for item in data:
row = "|"
for col in columns:
value = str(item.get(col, ''))
if len(value) > col_widths[col]:
value = value[:col_widths[col] - 3] + "..."
row += f" {value:<{col_widths[col]}} |"
print(row)
print(separator)
def run_interactive(self):
self.is_interactive = True
print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
while True:
try:
command = input("admin> ").strip()
if not command:
continue
print(f"command: {command}")
result = self.parser.parse_command(command)
self.execute_command(result)
if isinstance(result, Tree):
continue
if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
break
except KeyboardInterrupt:
print("\nUse '\\q' to quit")
except EOFError:
print("\nGoodbye!")
break
def run_single_command(self, args):
conn_info = self._parse_connection_args(args)
if 'error' in conn_info:
print(f"Error: {conn_info['error']}")
return
def _parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
parser = argparse.ArgumentParser(description='Admin CLI Client', add_help=False)
parser.add_argument('-h', '--host', default='localhost', help='Admin service host')
parser.add_argument('-p', '--port', type=int, default=8080, help='Admin service port')
try:
parsed_args, remaining_args = parser.parse_known_args(args)
return {
'host': parsed_args.host,
'port': parsed_args.port,
}
except SystemExit:
return {'error': 'Invalid connection arguments'}
def execute_command(self, parsed_command: Dict[str, Any]):
command_dict: dict
if isinstance(parsed_command, Tree):
command_dict = parsed_command.children[0]
else:
if parsed_command['type'] == 'error':
print(f"Error: {parsed_command['message']}")
return
else:
command_dict = parsed_command
# print(f"Parsed command: {command_dict}")
command_type = command_dict['type']
match command_type:
case 'list_services':
self._handle_list_services(command_dict)
case 'show_service':
self._handle_show_service(command_dict)
case 'restart_service':
self._handle_restart_service(command_dict)
case 'shutdown_service':
self._handle_shutdown_service(command_dict)
case 'startup_service':
self._handle_startup_service(command_dict)
case 'list_users':
self._handle_list_users(command_dict)
case 'show_user':
self._handle_show_user(command_dict)
case 'drop_user':
self._handle_drop_user(command_dict)
case 'alter_user':
self._handle_alter_user(command_dict)
case 'create_user':
self._handle_create_user(command_dict)
case 'activate_user':
self._handle_activate_user(command_dict)
case 'list_datasets':
self._handle_list_datasets(command_dict)
case 'list_agents':
self._handle_list_agents(command_dict)
case 'meta':
self._handle_meta_command(command_dict)
case _:
print(f"Command '{command_type}' would be executed with API")
def _handle_list_services(self, command):
print("Listing all services")
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
def _handle_show_service(self, command):
service_id: int = command['number']
print(f"Showing service: {service_id}")
def _handle_restart_service(self, command):
service_id: int = command['number']
print(f"Restart service {service_id}")
def _handle_shutdown_service(self, command):
service_id: int = command['number']
print(f"Shutdown service {service_id}")
def _handle_startup_service(self, command):
service_id: int = command['number']
print(f"Startup service {service_id}")
def _handle_list_users(self, command):
print("Listing all users")
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
def _handle_show_user(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
print(f"Showing user: {username}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_drop_user(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
print(f"Drop user: {username}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
response = requests.delete(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
def _handle_alter_user(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
password_tree: Tree = command['password']
password: str = password_tree.children[0].strip("'\"")
print(f"Alter user: {username}, password: {password}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
json={'new_password': encrypt(password)})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
def _handle_create_user(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
password_tree: Tree = command['password']
password: str = password_tree.children[0].strip("'\"")
role: str = command['role']
print(f"Create user: {username}, password: {password}, role: {role}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
response = requests.post(
url,
auth=HTTPBasicAuth(self.admin_account, self.admin_password),
json={'username': username, 'password': encrypt(password), 'role': role}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_activate_user(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
activate_tree: Tree = command['activate_status']
activate_status: str = activate_tree.children[0].strip("'\"")
if activate_status.lower() in ['on', 'off']:
print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
json={'activate_status': activate_status})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
else:
print(f"Unknown activate status: {activate_status}.")
def _handle_list_datasets(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
print(f"Listing all datasets of user: {username}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_list_agents(self, command):
username_tree: Tree = command['username']
username: str = username_tree.children[0].strip("'\"")
print(f"Listing all agents of user: {username}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_meta_command(self, command):
meta_command = command['command']
args = command.get('args', [])
if meta_command in ['?', 'h', 'help']:
self.show_help()
elif meta_command in ['q', 'quit', 'exit']:
print("Goodbye!")
else:
print(f"Meta command '{meta_command}' with args {args}")
def show_help(self):
"""Help info"""
help_text = """
Commands:
LIST SERVICES
SHOW SERVICE <service>
STARTUP SERVICE <service>
SHUTDOWN SERVICE <service>
RESTART SERVICE <service>
LIST USERS
SHOW USER <user>
DROP USER <user>
CREATE USER <user> <password>
ALTER USER PASSWORD <user> <new_password>
ALTER USER ACTIVE <user> <on/off>
LIST DATASETS OF <user>
LIST AGENTS OF <user>
Meta Commands:
\\?, \\h, \\help Show this help
\\q, \\quit, \\exit Quit the CLI
"""
print(help_text)
def main():
import sys
cli = AdminCLI()
if len(sys.argv) == 1 or (len(sys.argv) > 1 and sys.argv[1] == '-'):
print(r"""
____ ___ ______________ ___ __ _
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
""")
if cli.verify_admin(sys.argv):
cli.run_interactive()
else:
if cli.verify_admin(sys.argv):
cli.run_interactive()
# cli.run_single_command(sys.argv[1:])
if __name__ == '__main__':
main()

47
admin/admin_server.py Normal file
View File

@ -0,0 +1,47 @@
import os
import signal
import logging
import time
import threading
import traceback
from werkzeug.serving import run_simple
from flask import Flask
from routes import admin_bp
from api.utils.log_utils import init_root_logger
from api.constants import SERVICE_CONF
from api import settings
from config import load_configurations, SERVICE_CONFIGS
stop_event = threading.Event()
if __name__ == '__main__':
init_root_logger("admin_service")
logging.info(r"""
____ ___ ______________ ___ __ _
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
""")
app = Flask(__name__)
app.register_blueprint(admin_bp)
settings.init_settings()
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
try:
logging.info("RAGFlow Admin service start...")
run_simple(
hostname="0.0.0.0",
port=9381,
application=app,
threaded=True,
use_reloader=True,
use_debugger=True,
)
except Exception:
traceback.print_exc()
stop_event.set()
time.sleep(1)
os.kill(os.getpid(), signal.SIGKILL)

57
admin/auth.py Normal file
View File

@ -0,0 +1,57 @@
import logging
import uuid
from functools import wraps
from flask import request, jsonify
from exceptions import AdminException
from api.db.init_data import encode_to_base64
from api.db.services import UserService
def check_admin(username: str, password: str):
users = UserService.query(email=username)
if not users:
logging.info(f"Username: {username} is not registered!")
user_info = {
"id": uuid.uuid1().hex,
"password": encode_to_base64("admin"),
"nickname": "admin",
"is_superuser": True,
"email": "admin@ragflow.io",
"creator": "system",
"status": "1",
}
if not UserService.save(**user_info):
raise AdminException("Can't init admin.", 500)
user = UserService.query_user(username, password)
if user:
return True
else:
return False
def login_verify(f):
@wraps(f)
def decorated(*args, **kwargs):
auth = request.authorization
if not auth or 'username' not in auth.parameters or 'password' not in auth.parameters:
return jsonify({
"code": 401,
"message": "Authentication required",
"data": None
}), 200
username = auth.parameters['username']
password = auth.parameters['password']
# TODO: to check the username and password from DB
if check_admin(username, password) is False:
return jsonify({
"code": 403,
"message": "Access denied",
"data": None
}), 200
return f(*args, **kwargs)
return decorated

280
admin/config.py Normal file
View File

@ -0,0 +1,280 @@
import logging
import threading
from enum import Enum
from pydantic import BaseModel
from typing import Any
from api.utils.configs import read_config
from urllib.parse import urlparse
class ServiceConfigs:
def __init__(self):
self.configs = []
self.lock = threading.Lock()
SERVICE_CONFIGS = ServiceConfigs
class ServiceType(Enum):
METADATA = "metadata"
RETRIEVAL = "retrieval"
MESSAGE_QUEUE = "message_queue"
RAGFLOW_SERVER = "ragflow_server"
TASK_EXECUTOR = "task_executor"
FILE_STORE = "file_store"
class BaseConfig(BaseModel):
id: int
name: str
host: str
port: int
service_type: str
def to_dict(self) -> dict[str, Any]:
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, 'service_type': self.service_type}
class MetaConfig(BaseConfig):
meta_type: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['meta_type'] = self.meta_type
result['extra'] = extra_dict
return result
class MySQLConfig(MetaConfig):
username: str
password: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['username'] = self.username
extra_dict['password'] = self.password
result['extra'] = extra_dict
return result
class PostgresConfig(MetaConfig):
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
return result
class RetrievalConfig(BaseConfig):
retrieval_type: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['retrieval_type'] = self.retrieval_type
result['extra'] = extra_dict
return result
class InfinityConfig(RetrievalConfig):
db_name: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['db_name'] = self.db_name
result['extra'] = extra_dict
return result
class ElasticsearchConfig(RetrievalConfig):
username: str
password: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['username'] = self.username
extra_dict['password'] = self.password
result['extra'] = extra_dict
return result
class MessageQueueConfig(BaseConfig):
mq_type: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['mq_type'] = self.mq_type
result['extra'] = extra_dict
return result
class RedisConfig(MessageQueueConfig):
database: int
password: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['database'] = self.database
extra_dict['password'] = self.password
result['extra'] = extra_dict
return result
class RabbitMQConfig(MessageQueueConfig):
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
return result
class RAGFlowServerConfig(BaseConfig):
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
return result
class TaskExecutorConfig(BaseConfig):
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
return result
class FileStoreConfig(BaseConfig):
store_type: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['store_type'] = self.store_type
result['extra'] = extra_dict
return result
class MinioConfig(FileStoreConfig):
user: str
password: str
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if 'extra' not in result:
result['extra'] = dict()
extra_dict = result['extra'].copy()
extra_dict['user'] = self.user
extra_dict['password'] = self.password
result['extra'] = extra_dict
return result
def load_configurations(config_path: str) -> list[BaseConfig]:
raw_configs = read_config(config_path)
configurations = []
ragflow_count = 0
id_count = 0
for k, v in raw_configs.items():
match (k):
case "ragflow":
name: str = f'ragflow_{ragflow_count}'
host: str = v['host']
http_port: int = v['http_port']
config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port, service_type="ragflow_server")
configurations.append(config)
id_count += 1
case "es":
name: str = 'elasticsearch'
url = v['hosts']
parsed = urlparse(url)
host: str = parsed.hostname
port: int = parsed.port
username: str = v.get('username')
password: str = v.get('password')
config = ElasticsearchConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval",
retrieval_type="elasticsearch",
username=username, password=password)
configurations.append(config)
id_count += 1
case "infinity":
name: str = 'infinity'
url = v['uri']
parts = url.split(':', 1)
host = parts[0]
port = int(parts[1])
database: str = v.get('db_name', 'default_db')
config = InfinityConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", retrieval_type="infinity",
db_name=database)
configurations.append(config)
id_count += 1
case "minio":
name: str = 'minio'
url = v['host']
parts = url.split(':', 1)
host = parts[0]
port = int(parts[1])
user = v.get('user')
password = v.get('password')
config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, service_type="file_store",
store_type="minio")
configurations.append(config)
id_count += 1
case "redis":
name: str = 'redis'
url = v['host']
parts = url.split(':', 1)
host = parts[0]
port = int(parts[1])
password = v.get('password')
db: int = v.get('db')
config = RedisConfig(id=id_count, name=name, host=host, port=port, password=password, database=db,
service_type="message_queue", mq_type="redis")
configurations.append(config)
id_count += 1
case "mysql":
name: str = 'mysql'
host: str = v.get('host')
port: int = v.get('port')
username = v.get('user')
password = v.get('password')
config = MySQLConfig(id=id_count, name=name, host=host, port=port, username=username, password=password,
service_type="meta_data", meta_type="mysql")
configurations.append(config)
id_count += 1
case "admin":
pass
case _:
logging.warning(f"Unknown configuration key: {k}")
continue
return configurations

17
admin/exceptions.py Normal file
View File

@ -0,0 +1,17 @@
class AdminException(Exception):
def __init__(self, message, code=400):
super().__init__(message)
self.code = code
self.message = message
class UserNotFoundError(AdminException):
def __init__(self, username):
super().__init__(f"User '{username}' not found", 404)
class UserAlreadyExistsError(AdminException):
def __init__(self, username):
super().__init__(f"User '{username}' already exists", 409)
class CannotDeleteAdminError(AdminException):
def __init__(self):
super().__init__("Cannot delete admin account", 403)

0
admin/models.py Normal file
View File

15
admin/responses.py Normal file
View File

@ -0,0 +1,15 @@
from flask import jsonify
def success_response(data=None, message="Success", code = 0):
return jsonify({
"code": code,
"message": message,
"data": data
}), 200
def error_response(message="Error", code=-1, data=None):
return jsonify({
"code": code,
"message": message,
"data": data
}), 400

190
admin/routes.py Normal file
View File

@ -0,0 +1,190 @@
from flask import Blueprint, request
from auth import login_verify
from responses import success_response, error_response
from services import UserMgr, ServiceMgr, UserServiceMgr
from exceptions import AdminException
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
@admin_bp.route('/auth', methods=['GET'])
@login_verify
def auth_admin():
try:
return success_response(None, "Admin is authorized", 0)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users', methods=['GET'])
@login_verify
def list_users():
try:
users = UserMgr.get_all_users()
return success_response(users, "Get all users", 0)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users', methods=['POST'])
@login_verify
def create_user():
try:
data = request.get_json()
if not data or 'username' not in data or 'password' not in data:
return error_response("Username and password are required", 400)
username = data['username']
password = data['password']
role = data.get('role', 'user')
res = UserMgr.create_user(username, password, role)
if res["success"]:
user_info = res["user_info"]
user_info.pop("password") # do not return password
return success_response(user_info, "User created successfully")
else:
return error_response("create user failed")
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e))
@admin_bp.route('/users/<username>', methods=['DELETE'])
@login_verify
def delete_user(username):
try:
res = UserMgr.delete_user(username)
if res["success"]:
return success_response(None, res["message"])
else:
return error_response(res["message"])
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>/password', methods=['PUT'])
@login_verify
def change_password(username):
try:
data = request.get_json()
if not data or 'new_password' not in data:
return error_response("New password is required", 400)
new_password = data['new_password']
msg = UserMgr.update_user_password(username, new_password)
return success_response(None, msg)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
@login_verify
def alter_user_activate_status(username):
try:
data = request.get_json()
if not data or 'activate_status' not in data:
return error_response("Activation status is required", 400)
activate_status = data['activate_status']
msg = UserMgr.update_user_activate_status(username, activate_status)
return success_response(None, msg)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>', methods=['GET'])
@login_verify
def get_user_details(username):
try:
user_details = UserMgr.get_user_details(username)
return success_response(user_details)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
@login_verify
def get_user_datasets(username):
try:
datasets_list = UserServiceMgr.get_user_datasets(username)
return success_response(datasets_list)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/users/<username>/agents', methods=['GET'])
@login_verify
def get_user_agents(username):
try:
agents_list = UserServiceMgr.get_user_agents(username)
return success_response(agents_list)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/services', methods=['GET'])
@login_verify
def get_services():
try:
services = ServiceMgr.get_all_services()
return success_response(services, "Get all services", 0)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/service_types/<service_type>', methods=['GET'])
@login_verify
def get_services_by_type(service_type_str):
try:
services = ServiceMgr.get_services_by_type(service_type_str)
return success_response(services)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/services/<service_id>', methods=['GET'])
@login_verify
def get_service(service_id):
try:
services = ServiceMgr.get_service_details(service_id)
return success_response(services)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/services/<service_id>', methods=['DELETE'])
@login_verify
def shutdown_service(service_id):
try:
services = ServiceMgr.shutdown_service(service_id)
return success_response(services)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route('/services/<service_id>', methods=['PUT'])
@login_verify
def restart_service(service_id):
try:
services = ServiceMgr.restart_service(service_id)
return success_response(services)
except Exception as e:
return error_response(str(e), 500)

175
admin/services.py Normal file
View File

@ -0,0 +1,175 @@
import re
from werkzeug.security import check_password_hash
from api.db import ActiveEnum
from api.db.services import UserService
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_service import TenantService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.crypt import decrypt
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
from config import SERVICE_CONFIGS
class UserMgr:
@staticmethod
def get_all_users():
users = UserService.get_all_users()
result = []
for user in users:
result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date, 'is_active': user.is_active})
return result
@staticmethod
def get_user_details(username):
# use email to query
users = UserService.query_user_by_email(username)
result = []
for user in users:
result.append({
'email': user.email,
'language': user.language,
'last_login_time': user.last_login_time,
'is_authenticated': user.is_authenticated,
'is_active': user.is_active,
'is_anonymous': user.is_anonymous,
'login_channel': user.login_channel,
'status': user.status,
'is_superuser': user.is_superuser,
'create_date': user.create_date,
'update_date': user.update_date
})
return result
@staticmethod
def create_user(username, password, role="user") -> dict:
# Validate the email address
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
raise AdminException(f"Invalid email address: {username}!")
# Check if the email address is already used
if UserService.query(email=username):
raise UserAlreadyExistsError(username)
# Construct user info data
user_info_dict = {
"email": username,
"nickname": "", # ask user to edit it manually in settings.
"password": decrypt(password),
"login_channel": "password",
"is_superuser": role == "admin",
}
return create_new_user(user_info_dict)
@staticmethod
def delete_user(username):
# use email to delete
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
if len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
usr = user_list[0]
return delete_user_data(usr.id)
@staticmethod
def update_user_password(username, new_password) -> str:
# use email to find user. check exist and unique.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# check new_password different from old.
usr = user_list[0]
psw = decrypt(new_password)
if check_password_hash(usr.password, psw):
return "Same password, no need to update!"
# update password
UserService.update_user_password(usr.id, psw)
return "Password updated successfully!"
@staticmethod
def update_user_activate_status(username, activate_status: str):
# use email to find user. check exist and unique.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# check activate status different from new
usr = user_list[0]
# format activate_status before handle
_activate_status = activate_status.lower()
target_status = {
'on': ActiveEnum.ACTIVE.value,
'off': ActiveEnum.INACTIVE.value,
}.get(_activate_status)
if not target_status:
raise AdminException(f"Invalid activate_status: {activate_status}")
if target_status == usr.is_active:
return f"User activate status is already {_activate_status}!"
# update is_active
UserService.update_user(usr.id, {"is_active": target_status})
return f"Turn {_activate_status} user activate status successfully!"
class UserServiceMgr:
@staticmethod
def get_user_datasets(username):
# use email to find user.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# find tenants
usr = user_list[0]
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
tenant_ids = [m["tenant_id"] for m in tenants]
# filter permitted kb and owned kb
return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
@staticmethod
def get_user_agents(username):
# use email to find user.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
# find tenants
usr = user_list[0]
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
tenant_ids = [m["tenant_id"] for m in tenants]
# filter permitted agents and owned agents
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
return [{
'title': r['title'],
'permission': r['permission'],
'canvas_type': r['canvas_type'],
'canvas_category': r['canvas_category']
} for r in res]
class ServiceMgr:
@staticmethod
def get_all_services():
result = []
configs = SERVICE_CONFIGS.configs
for config in configs:
result.append(config.to_dict())
return result
@staticmethod
def get_services_by_type(service_type_str: str):
raise AdminException("get_services_by_type: not implemented")
@staticmethod
def get_service_details(service_id: int):
raise AdminException("get_service_details: not implemented")
@staticmethod
def shutdown_service(service_id: int):
raise AdminException("shutdown_service: not implemented")
@staticmethod
def restart_service(service_id: int):
raise AdminException("restart_service: not implemented")

View File

@ -27,7 +27,7 @@ from agent.component import component_class
from agent.component.base import ComponentBase from agent.component.base import ComponentBase
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.utils import get_uuid, hash_str2int from api.utils import get_uuid, hash_str2int
from rag.prompts.prompts import chunks_format from rag.prompts.generator import chunks_format
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN
class Graph: class Graph:
@ -490,7 +490,8 @@ class Canvas(Graph):
r = self.retrieval[-1] r = self.retrieval[-1]
for ck in chunks_format({"chunks": chunks}): for ck in chunks_format({"chunks": chunks}):
cid = hash_str2int(ck["id"], 100) cid = hash_str2int(ck["id"], 500)
# cid = uuid.uuid5(uuid.NAMESPACE_DNS, ck["id"])
if cid not in r: if cid not in r:
r["chunks"][cid] = ck r["chunks"][cid] = ck

View File

@ -28,9 +28,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout from api.utils.api_utils import timeout
from rag.prompts import message_fit_in from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \ citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM from agent.component.llm import LLMParam, LLM
@ -138,7 +137,7 @@ class Agent(LLM, ToolBase):
res.update(cpn.get_input_form()) res.update(cpn.get_input_form())
return res return res
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if kwargs.get("user_prompt"): if kwargs.get("user_prompt"):
usr_pmt = "" usr_pmt = ""

View File

@ -244,7 +244,7 @@ class ComponentParamBase(ABC):
if not value_legal: if not value_legal:
raise ValueError( raise ValueError(
"Plase check runtime conf, {} = {} does not match user-parameter restriction".format( "Please check runtime conf, {} = {} does not match user-parameter restriction".format(
variable, value variable, value
) )
) )
@ -431,7 +431,7 @@ class ComponentBase(ABC):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output() return self.output()
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
raise NotImplementedError() raise NotImplementedError()

View File

@ -28,7 +28,7 @@ from rag.llm.chat_model import ERROR_PREFIX
class CategorizeParam(LLMParam): class CategorizeParam(LLMParam):
""" """
Define the Categorize component parameters. Define the categorize component parameters.
""" """
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -80,7 +80,7 @@ Here's description of each category:
- Prioritize the most specific applicable category - Prioritize the most specific applicable category
- Return only the category name without explanations - Return only the category name without explanations
- Use "Other" only when no other category fits - Use "Other" only when no other category fits
""".format( """.format(
"\n - ".join(list(self.category_description.keys())), "\n - ".join(list(self.category_description.keys())),
"\n".join(descriptions) "\n".join(descriptions)
@ -96,7 +96,7 @@ Here's description of each category:
class Categorize(LLM, ABC): class Categorize(LLM, ABC):
component_name = "Categorize" component_name = "Categorize"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
msg = self._canvas.get_history(self._param.message_history_window_size) msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg: if not msg:
@ -112,7 +112,7 @@ class Categorize(LLM, ABC):
user_prompt = """ user_prompt = """
---- Real Data ---- ---- Real Data ----
{} {}
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg])) """.format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}") logging.info(f"input: {user_prompt}, answer: {str(ans)}")
@ -134,4 +134,4 @@ class Categorize(LLM, ABC):
self.set_output("_next", cpn_ids) self.set_output("_next", cpn_ids)
def thoughts(self) -> str: def thoughts(self) -> str:
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()])) return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))

View File

@ -53,7 +53,7 @@ class InvokeParam(ComponentParamBase):
class Invoke(ComponentBase, ABC): class Invoke(ComponentBase, ABC):
component_name = "Invoke" component_name = "Invoke"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
args = {} args = {}
for para in self._param.variables: for para in self._param.variables:

View File

@ -26,8 +26,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout from api.utils.api_utils import timeout
from rag.prompts import message_fit_in, citation_prompt from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt
from rag.prompts.prompts import tool_call_summary
class LLMParam(ComponentParamBase): class LLMParam(ComponentParamBase):
@ -82,9 +81,9 @@ class LLMParam(ComponentParamBase):
class LLM(ComponentBase): class LLM(ComponentBase):
component_name = "LLM" component_name = "LLM"
def __init__(self, canvas, id, param: ComponentParamBase): def __init__(self, canvas, component_id, param: ComponentParamBase):
super().__init__(canvas, id, param) super().__init__(canvas, component_id, param)
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
self._param.llm_id, max_retries=self._param.max_retries, self._param.llm_id, max_retries=self._param.max_retries,
retry_interval=self._param.delay_after_error retry_interval=self._param.delay_after_error
@ -206,7 +205,7 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt) yield delta(txt)
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
def clean_formated_answer(ans: str) -> str: def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@ -214,7 +213,7 @@ class LLM(ComponentBase):
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL) return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
prompt, msg, _ = self._prepare_prompt_variables() prompt, msg, _ = self._prepare_prompt_variables()
error = "" error: str = ""
if self._param.output_structure: if self._param.output_structure:
prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2) prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)

View File

@ -49,7 +49,7 @@ class MessageParam(ComponentParamBase):
class Message(ComponentBase): class Message(ComponentBase):
component_name = "Message" component_name = "Message"
def get_kwargs(self, script:str, kwargs:dict = {}, delimeter:str=None) -> tuple[str, dict[str, str | list | Any]]: def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
for k,v in self.get_input_elements_from_text(script).items(): for k,v in self.get_input_elements_from_text(script).items():
if k in kwargs: if k in kwargs:
continue continue
@ -60,8 +60,8 @@ class Message(ComponentBase):
if isinstance(v, partial): if isinstance(v, partial):
for t in v(): for t in v():
ans += t ans += t
elif isinstance(v, list) and delimeter: elif isinstance(v, list) and delimiter:
ans = delimeter.join([str(vv) for vv in v]) ans = delimiter.join([str(vv) for vv in v])
elif not isinstance(v, str): elif not isinstance(v, str):
try: try:
ans = json.dumps(v, ensure_ascii=False) ans = json.dumps(v, ensure_ascii=False)
@ -127,7 +127,7 @@ class Message(ComponentBase):
] ]
return any([re.search(p, content) for p in patt]) return any([re.search(p, content) for p in patt])
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
rand_cnt = random.choice(self._param.content) rand_cnt = random.choice(self._param.content)
if self._param.stream and not self._is_jinjia2(rand_cnt): if self._param.stream and not self._is_jinjia2(rand_cnt):

View File

@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
"type": "line" "type": "line"
} for k, o in self.get_input_elements_from_text(self._param.script).items()} } for k, o in self.get_input_elements_from_text(self._param.script).items()}
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if self._param.method == "split": if self._param.method == "split":
self._split(kwargs.get("line")) self._split(kwargs.get("line"))
@ -90,7 +90,7 @@ class StringTransform(Message, ABC):
for k,v in kwargs.items(): for k,v in kwargs.items():
if not v: if not v:
v = "" v = ""
script = re.sub(k, v, script) script = re.sub(k, lambda match: v, script)
self.set_output("result", script) self.set_output("result", script)

View File

@ -61,7 +61,7 @@ class SwitchParam(ComponentParamBase):
class Switch(ComponentBase, ABC): class Switch(ComponentBase, ABC):
component_name = "Switch" component_name = "Switch"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
for cond in self._param.conditions: for cond in self._param.conditions:
res = [] res = []

View File

@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
class ArXiv(ToolBase, ABC): class ArXiv(ToolBase, ABC):
component_name = "ArXiv" component_name = "ArXiv"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -97,6 +97,6 @@ class ArXiv(ToolBase, ABC):
def thoughts(self) -> str: def thoughts(self) -> str:
return """ return """
Keywords: {} Keywords: {}
Looking for the most relevant articles. Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!")) """.format(self.get_input().get("query", "-_-!"))

View File

@ -22,7 +22,7 @@ from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase from agent.component.base import ComponentParamBase, ComponentBase
from api.utils import hash_str2int from api.utils import hash_str2int
from rag.llm.chat_model import ToolCallSession from rag.llm.chat_model import ToolCallSession
from rag.prompts.prompts import kb_prompt from rag.prompts.generator import kb_prompt
from rag.utils.mcp_tool_call_conn import MCPToolCallSession from rag.utils.mcp_tool_call_conn import MCPToolCallSession
from timeit import default_timer as timer from timeit import default_timer as timer

View File

@ -129,7 +129,7 @@ module.exports = { main };
class CodeExec(ToolBase, ABC): class CodeExec(ToolBase, ABC):
component_name = "CodeExec" component_name = "CodeExec"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
lang = kwargs.get("lang", self._param.lang) lang = kwargs.get("lang", self._param.lang)
script = kwargs.get("script", self._param.script) script = kwargs.get("script", self._param.script)
@ -157,7 +157,7 @@ class CodeExec(ToolBase, ABC):
try: try:
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:") logging.info(f"http://{settings.SANDBOX_HOST}:9385/run", code_req, resp.status_code)
if resp.status_code != 200: if resp.status_code != 200:
resp.raise_for_status() resp.raise_for_status()
body = resp.json() body = resp.json()

View File

@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
class DuckDuckGo(ToolBase, ABC): class DuckDuckGo(ToolBase, ABC):
component_name = "DuckDuckGo" component_name = "DuckDuckGo"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -115,6 +115,6 @@ class DuckDuckGo(ToolBase, ABC):
def thoughts(self) -> str: def thoughts(self) -> str:
return """ return """
Keywords: {} Keywords: {}
Looking for the most relevant articles. Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!")) """.format(self.get_input().get("query", "-_-!"))

View File

@ -98,8 +98,8 @@ class EmailParam(ToolParamBase):
class Email(ToolBase, ABC): class Email(ToolBase, ABC):
component_name = "Email" component_name = "Email"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("to_email"): if not kwargs.get("to_email"):
self.set_output("success", False) self.set_output("success", False)
@ -212,4 +212,4 @@ class Email(ToolBase, ABC):
To: {} To: {}
Subject: {} Subject: {}
Your email is on its way—sit tight! Your email is on its way—sit tight!
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!")) """.format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))

View File

@ -53,7 +53,7 @@ class ExeSQLParam(ToolParamBase):
self.max_records = 1024 self.max_records = 1024
def check(self): def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql']) self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
self.check_empty(self.database, "Database name") self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username") self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address") self.check_empty(self.host, "IP Address")
@ -78,7 +78,7 @@ class ExeSQLParam(ToolParamBase):
class ExeSQL(ToolBase, ABC): class ExeSQL(ToolBase, ABC):
component_name = "ExeSQL" component_name = "ExeSQL"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
def convert_decimals(obj): def convert_decimals(obj):
@ -123,6 +123,55 @@ class ExeSQL(ToolBase, ABC):
r'PWD=' + self._param.password r'PWD=' + self._param.password
) )
db = pyodbc.connect(conn_str) db = pyodbc.connect(conn_str)
elif self._param.db_type == 'IBM DB2':
import ibm_db
conn_str = (
f"DATABASE={self._param.database};"
f"HOSTNAME={self._param.host};"
f"PORT={self._param.port};"
f"PROTOCOL=TCPIP;"
f"UID={self._param.username};"
f"PWD={self._param.password};"
)
try:
conn = ibm_db.connect(conn_str, "", "")
except Exception as e:
raise Exception("Database Connection Failed! \n" + str(e))
sql_res = []
formalized_content = []
for single_sql in sqls:
single_sql = single_sql.replace("```", "").strip()
if not single_sql:
continue
single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
stmt = ibm_db.exec_immediate(conn, single_sql)
rows = []
row = ibm_db.fetch_assoc(stmt)
while row and len(rows) < self._param.max_records:
rows.append(row)
row = ibm_db.fetch_assoc(stmt)
if not rows:
sql_res.append({"content": "No record in the database!"})
continue
df = pd.DataFrame(rows)
for col in df.columns:
if pd.api.types.is_datetime64_any_dtype(df[col]):
df[col] = df[col].dt.strftime("%Y-%m-%d")
df = df.where(pd.notnull(df), None)
sql_res.append(convert_decimals(df.to_dict(orient="records")))
formalized_content.append(df.to_markdown(index=False, floatfmt=".6f"))
ibm_db.close(conn)
self.set_output("json", sql_res)
self.set_output("formalized_content", "\n\n".join(formalized_content))
return self.output("formalized_content")
try: try:
cursor = db.cursor() cursor = db.cursor()
except Exception as e: except Exception as e:
@ -150,6 +199,8 @@ class ExeSQL(ToolBase, ABC):
if pd.api.types.is_datetime64_any_dtype(single_res[col]): if pd.api.types.is_datetime64_any_dtype(single_res[col]):
single_res[col] = single_res[col].dt.strftime('%Y-%m-%d') single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
single_res = single_res.where(pd.notnull(single_res), None)
sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))

View File

@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
class GitHub(ToolBase, ABC): class GitHub(ToolBase, ABC):
component_name = "GitHub" component_name = "GitHub"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -88,4 +88,4 @@ class GitHub(ToolBase, ABC):
assert False, self.output() assert False, self.output()
def thoughts(self) -> str: def thoughts(self) -> str:
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!")) return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))

View File

@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
class Google(ToolBase, ABC): class Google(ToolBase, ABC):
component_name = "Google" component_name = "Google"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("q"): if not kwargs.get("q"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -154,6 +154,6 @@ class Google(ToolBase, ABC):
def thoughts(self) -> str: def thoughts(self) -> str:
return """ return """
Keywords: {} Keywords: {}
Looking for the most relevant articles. Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!")) """.format(self.get_input().get("query", "-_-!"))

View File

@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
class GoogleScholar(ToolBase, ABC): class GoogleScholar(ToolBase, ABC):
component_name = "GoogleScholar" component_name = "GoogleScholar"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -93,4 +93,4 @@ class GoogleScholar(ToolBase, ABC):
assert False, self.output() assert False, self.output()
def thoughts(self) -> str: def thoughts(self) -> str:
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!")) return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))

View File

@ -33,7 +33,7 @@ class PubMedParam(ToolParamBase):
self.meta:ToolMeta = { self.meta:ToolMeta = {
"name": "pubmed_search", "name": "pubmed_search",
"description": """ "description": """
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics. PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
In addition to MEDLINE, PubMed provides access to: In addition to MEDLINE, PubMed provides access to:
- older references from the print version of Index Medicus, back to 1951 and earlier - older references from the print version of Index Medicus, back to 1951 and earlier
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery - references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
class PubMed(ToolBase, ABC): class PubMed(ToolBase, ABC):
component_name = "PubMed" component_name = "PubMed"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -105,4 +105,4 @@ class PubMed(ToolBase, ABC):
assert False, self.output() assert False, self.output()
def thoughts(self) -> str: def thoughts(self) -> str:
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!")) return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))

View File

@ -23,8 +23,7 @@ from api.db.services.llm_service import LLMBundle
from api import settings from api import settings
from api.utils.api_utils import timeout from api.utils.api_utils import timeout
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.prompts import kb_prompt from rag.prompts.generator import cross_languages, kb_prompt
from rag.prompts.prompts import cross_languages
class RetrievalParam(ToolParamBase): class RetrievalParam(ToolParamBase):
@ -75,7 +74,7 @@ class RetrievalParam(ToolParamBase):
class Retrieval(ToolBase, ABC): class Retrieval(ToolBase, ABC):
component_name = "Retrieval" component_name = "Retrieval"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response) self.set_output("formalized_content", self._param.empty_response)
@ -163,13 +162,20 @@ class Retrieval(ToolBase, ABC):
self.set_output("formalized_content", self._param.empty_response) self.set_output("formalized_content", self._param.empty_response)
return return
# Format the chunks for JSON output (similar to how other tools do it)
json_output = kbinfos["chunks"].copy()
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"]) self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True)) form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
# Set both formalized content and JSON output
self.set_output("formalized_content", form_cnt) self.set_output("formalized_content", form_cnt)
self.set_output("json", json_output)
return form_cnt return form_cnt
def thoughts(self) -> str: def thoughts(self) -> str:
return """ return """
Keywords: {} Keywords: {}
Looking for the most relevant articles. Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!")) """.format(self.get_input().get("query", "-_-!"))

View File

@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
class SearXNG(ToolBase, ABC): class SearXNG(ToolBase, ABC):
component_name = "SearXNG" component_name = "SearXNG"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
# Gracefully handle try-run without inputs # Gracefully handle try-run without inputs
query = kwargs.get("query") query = kwargs.get("query")
@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
last_e = "" last_e = ""
for _ in range(self._param.max_retries+1): for _ in range(self._param.max_retries+1):
try: try:
# 构建搜索参数
search_params = { search_params = {
'q': query, 'q': query,
'format': 'json', 'format': 'json',
@ -104,33 +103,29 @@ class SearXNG(ToolBase, ABC):
'pageno': 1 'pageno': 1
} }
# 发送搜索请求
response = requests.get( response = requests.get(
f"{searxng_url}/search", f"{searxng_url}/search",
params=search_params, params=search_params,
timeout=10 timeout=10
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
# 验证响应数据
if not data or not isinstance(data, dict): if not data or not isinstance(data, dict):
raise ValueError("Invalid response from SearXNG") raise ValueError("Invalid response from SearXNG")
results = data.get("results", []) results = data.get("results", [])
if not isinstance(results, list): if not isinstance(results, list):
raise ValueError("Invalid results format from SearXNG") raise ValueError("Invalid results format from SearXNG")
# 限制结果数量
results = results[:self._param.top_n] results = results[:self._param.top_n]
# 处理搜索结果
self._retrieve_chunks(results, self._retrieve_chunks(results,
get_title=lambda r: r.get("title", ""), get_title=lambda r: r.get("title", ""),
get_url=lambda r: r.get("url", ""), get_url=lambda r: r.get("url", ""),
get_content=lambda r: r.get("content", "")) get_content=lambda r: r.get("content", ""))
self.set_output("json", results) self.set_output("json", results)
return self.output("formalized_content") return self.output("formalized_content")
@ -151,6 +146,6 @@ class SearXNG(ToolBase, ABC):
def thoughts(self) -> str: def thoughts(self) -> str:
return """ return """
Keywords: {} Keywords: {}
Searching with SearXNG for relevant results... Searching with SearXNG for relevant results...
""".format(self.get_input().get("query", "-_-!")) """.format(self.get_input().get("query", "-_-!"))

View File

@ -31,7 +31,7 @@ class TavilySearchParam(ToolParamBase):
self.meta:ToolMeta = { self.meta:ToolMeta = {
"name": "tavily_search", "name": "tavily_search",
"description": """ "description": """
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results. Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
When searching: When searching:
- Start with specific query which should focus on just a single aspect. - Start with specific query which should focus on just a single aspect.
- Number of keywords in query should be less than 5. - Number of keywords in query should be less than 5.
@ -101,7 +101,7 @@ When searching:
class TavilySearch(ToolBase, ABC): class TavilySearch(ToolBase, ABC):
component_name = "TavilySearch" component_name = "TavilySearch"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -136,7 +136,7 @@ class TavilySearch(ToolBase, ABC):
def thoughts(self) -> str: def thoughts(self) -> str:
return """ return """
Keywords: {} Keywords: {}
Looking for the most relevant articles. Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!")) """.format(self.get_input().get("query", "-_-!"))
@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
class TavilyExtract(ToolBase, ABC): class TavilyExtract(ToolBase, ABC):
component_name = "TavilyExtract" component_name = "TavilyExtract"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
self.tavily_client = TavilyClient(api_key=self._param.api_key) self.tavily_client = TavilyClient(api_key=self._param.api_key)
last_e = None last_e = None
@ -224,4 +224,4 @@ class TavilyExtract(ToolBase, ABC):
assert False, self.output() assert False, self.output()
def thoughts(self) -> str: def thoughts(self) -> str:
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!")) return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))

View File

@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
class WenCai(ToolBase, ABC): class WenCai(ToolBase, ABC):
component_name = "WenCai" component_name = "WenCai"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("report", "") self.set_output("report", "")
@ -111,4 +111,4 @@ class WenCai(ToolBase, ABC):
assert False, self.output() assert False, self.output()
def thoughts(self) -> str: def thoughts(self) -> str:
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!")) return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))

View File

@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
class Wikipedia(ToolBase, ABC): class Wikipedia(ToolBase, ABC):
component_name = "Wikipedia" component_name = "Wikipedia"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("query"): if not kwargs.get("query"):
self.set_output("formalized_content", "") self.set_output("formalized_content", "")
@ -99,6 +99,6 @@ class Wikipedia(ToolBase, ABC):
def thoughts(self) -> str: def thoughts(self) -> str:
return """ return """
Keywords: {} Keywords: {}
Looking for the most relevant articles. Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!")) """.format(self.get_input().get("query", "-_-!"))

View File

@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
class YahooFinance(ToolBase, ABC): class YahooFinance(ToolBase, ABC):
component_name = "YahooFinance" component_name = "YahooFinance"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if not kwargs.get("stock_code"): if not kwargs.get("stock_code"):
self.set_output("report", "") self.set_output("report", "")
@ -111,4 +111,4 @@ class YahooFinance(ToolBase, ABC):
assert False, self.output() assert False, self.output()
def thoughts(self) -> str: def thoughts(self) -> str:
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!")) return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))

View File

@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from api.db import StatusEnum from api.db import StatusEnum
from api.db.db_models import close_connection from api.db.db_models import close_connection
from api.db.services import UserService from api.db.services import UserService
from api.utils import CustomJSONEncoder, commands from api.utils.json import CustomJSONEncoder
from api.utils import commands
from flask_mail import Mail from flask_mail import Mail
from flask_session import Session from flask_session import Session

View File

@ -39,7 +39,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
from api.utils.file_utils import filename_type, thumbnail from api.utils.file_utils import filename_type, thumbnail
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.prompts import keyword_extraction from rag.prompts.generator import keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import UserCanvasService

View File

@ -100,7 +100,7 @@ def save():
def get(canvas_id): def get(canvas_id):
if not UserCanvasService.accessible(canvas_id, current_user.id): if not UserCanvasService.accessible(canvas_id, current_user.id):
return get_data_error_result(message="canvas not found.") return get_data_error_result(message="canvas not found.")
e, c = UserCanvasService.get_by_tenant_id(canvas_id) e, c = UserCanvasService.get_by_canvas_id(canvas_id)
return get_json_result(data=c) return get_json_result(data=c)
@ -243,7 +243,7 @@ def reset():
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821 @manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id): def upload(canvas_id):
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id) e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
if not e: if not e:
return get_data_error_result(message="canvas not found.") return get_data_error_result(message="canvas not found.")
@ -393,6 +393,22 @@ def test_db_connect():
cursor = db.cursor() cursor = db.cursor()
cursor.execute("SELECT 1") cursor.execute("SELECT 1")
cursor.close() cursor.close()
elif req["db_type"] == 'IBM DB2':
import ibm_db
conn_str = (
f"DATABASE={req['database']};"
f"HOSTNAME={req['host']};"
f"PORT={req['port']};"
f"PROTOCOL=TCPIP;"
f"UID={req['username']};"
f"PWD={req['password']};"
)
logging.info(conn_str)
conn = ibm_db.connect(conn_str, "", "")
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
ibm_db.fetch_assoc(stmt)
ibm_db.close(conn)
return get_json_result(data="Database Connection Successful!")
else: else:
return server_error_response("Unsupported database type.") return server_error_response("Unsupported database type.")
if req["db_type"] != 'mssql': if req["db_type"] != 'mssql':
@ -529,7 +545,7 @@ def sessions(canvas_id):
@manager.route('/prompts', methods=['GET']) # noqa: F821 @manager.route('/prompts', methods=['GET']) # noqa: F821
@login_required @login_required
def prompts(): def prompts():
from rag.prompts.prompts import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
return get_json_result(data={ return get_json_result(data={
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER, "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
"plan_generation": NEXT_STEP, "plan_generation": NEXT_STEP,

View File

@ -33,8 +33,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
from rag.prompts import cross_languages, keyword_extraction from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
from rag.prompts.prompts import gen_meta_filter
from rag.settings import PAGERANK_FLD from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace from rag.utils import rmSpace

View File

@ -15,7 +15,7 @@
# #
import json import json
import re import re
import traceback import logging
from copy import deepcopy from copy import deepcopy
from flask import Response, request from flask import Response, request
from flask_login import current_user, login_required from flask_login import current_user, login_required
@ -29,8 +29,8 @@ from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from rag.prompts.prompt_template import load_prompt from rag.prompts.template import load_prompt
from rag.prompts.prompts import chunks_format from rag.prompts.generator import chunks_format
@manager.route("/set", methods=["POST"]) # noqa: F821 @manager.route("/set", methods=["POST"]) # noqa: F821
@ -226,7 +226,7 @@ def completion():
if not is_embedded: if not is_embedded:
ConversationService.update_by_id(conv.id, conv.to_dict()) ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e: except Exception as e:
traceback.print_exc() logging.exception(e)
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

View File

@ -577,7 +577,7 @@ def change_parser():
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
try: try:
if req.get("pipeline_id"): if "pipeline_id" in req:
if doc.pipeline_id == req["pipeline_id"]: if doc.pipeline_id == req["pipeline_id"]:
return get_json_result(data=True) return get_json_result(data=True)
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]}) DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})

View File

@ -246,6 +246,8 @@ def rm():
return get_data_error_result(message="File or Folder not found!") return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id: if not file.tenant_id:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
if file.source_type == FileSource.KNOWLEDGEBASE: if file.source_type == FileSource.KNOWLEDGEBASE:
continue continue
@ -292,6 +294,8 @@ def rename():
e, file = FileService.get_by_id(req["file_id"]) e, file = FileService.get_by_id(req["file_id"])
if not e: if not e:
return get_data_error_result(message="File not found!") return get_data_error_result(message="File not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
if file.type != FileType.FOLDER.value \ if file.type != FileType.FOLDER.value \
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path( and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
file.name.lower()).suffix: file.name.lower()).suffix:
@ -328,6 +332,8 @@ def get(file_id):
e, file = FileService.get_by_id(file_id) e, file = FileService.get_by_id(file_id)
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
blob = STORAGE_IMPL.get(file.parent_id, file.location) blob = STORAGE_IMPL.get(file.parent_id, file.location)
if not blob: if not blob:
@ -367,6 +373,8 @@ def move():
return get_data_error_result(message="File or Folder not found!") return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id: if not file.tenant_id:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
fe, _ = FileService.get_by_id(parent_id) fe, _ = FileService.get_by_id(parent_id)
if not fe: if not fe:
return get_data_error_result(message="Parent Folder not found!") return get_data_error_result(message="Parent Folder not found!")

View File

@ -40,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
from rag.prompts import cross_languages, keyword_extraction from rag.prompts.generator import cross_languages, keyword_extraction
from rag.utils import rmSpace from rag.utils import rmSpace
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL

View File

@ -3,9 +3,11 @@ import re
import flask import flask
from flask import request from flask import request
from pathlib import Path
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, token_required from api.utils.api_utils import server_error_response, token_required
from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType from api.db import FileType
@ -81,16 +83,16 @@ def upload(tenant_id):
return get_json_result(data=False, message="Can't find this folder!", code=404) return get_json_result(data=False, message="Can't find this folder!", code=404)
for file_obj in file_objs: for file_obj in file_objs:
# 文件路径处理 # Handle file path
full_path = '/' + file_obj.filename full_path = '/' + file_obj.filename
file_obj_names = full_path.split('/') file_obj_names = full_path.split('/')
file_len = len(file_obj_names) file_len = len(file_obj_names)
# 获取文件夹路径ID # Get folder path ID
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id]) file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list) len_id_list = len(file_id_list)
# 创建文件夹结构 # Crete file folder
if file_len != len_id_list: if file_len != len_id_list:
e, file = FileService.get_by_id(file_id_list[len_id_list - 1]) e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
if not e: if not e:
@ -666,3 +668,71 @@ def move(tenant_id):
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
def convert(tenant_id):
req = request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
try:
files = FileService.get_by_ids(file_ids)
files_set = dict({file.id: file for file in files})
for file_id in file_ids:
file = files_set[file_id]
if not file:
return get_json_result(message="File not found!", code=404)
file_ids_list = [file_id]
if file.type == FileType.FOLDER.value:
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
for id in file_ids_list:
informs = File2DocumentService.get_by_file_id(id)
# delete
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_json_result(message="Document not found!", code=404)
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_json_result(message="Tenant not found!", code=404)
if not DocumentService.remove_document(doc, tenant_id):
return get_json_result(
message="Database error (Document removal)!", code=404)
File2DocumentService.delete_by_file_id(id)
# insert
for kb_id in kb_ids:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_json_result(
message="Can't find this knowledgebase!", code=404)
e, file = FileService.get_by_id(id)
if not e:
return get_json_result(
message="Can't find this file!", code=404)
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
"parser_config": kb.parser_config,
"created_by": tenant_id,
"type": file.type,
"name": file.name,
"suffix": Path(file.name).suffix.lstrip("."),
"location": file.location,
"size": file.size
})
file2document = File2DocumentService.insert({
"id": get_uuid(),
"file_id": id,
"document_id": doc.id,
})
file2documents.append(file2document.to_json())
return get_json_result(data=file2documents)
except Exception as e:
return server_error_response(e)

View File

@ -38,9 +38,8 @@ from api.db.services.user_service import UserTenantService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.prompts import chunks_format from rag.prompts.template import load_prompt
from rag.prompts.prompt_template import load_prompt from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821 @manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821

View File

@ -37,7 +37,8 @@ from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify from flask import jsonify
from api.utils.health import run_health_checks from api.utils.health_utils import run_health_checks
@manager.route("/version", methods=["GET"]) # noqa: F821 @manager.route("/version", methods=["GET"]) # noqa: F821
@login_required @login_required

View File

@ -34,7 +34,6 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
from api.utils import ( from api.utils import (
current_timestamp, current_timestamp,
datetime_format, datetime_format,
decrypt,
download_img, download_img,
get_format_time, get_format_time,
get_uuid, get_uuid,
@ -46,6 +45,7 @@ from api.utils.api_utils import (
server_error_response, server_error_response,
validate_request, validate_request,
) )
from api.utils.crypt import decrypt
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821 @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -98,7 +98,14 @@ def login():
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password") return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
user = UserService.query_user(email, password) user = UserService.query_user(email, password)
if user:
if user and hasattr(user, 'is_active') and user.is_active == "0":
return get_json_result(
data=False,
code=settings.RetCode.FORBIDDEN,
message="This account has been disabled, please contact the administrator!",
)
elif user:
response_data = user.to_json() response_data = user.to_json()
user.access_token = get_uuid() user.access_token = get_uuid()
login_user(user) login_user(user)
@ -227,6 +234,9 @@ def oauth_callback(channel):
# User exists, try to log in # User exists, try to log in
user = users[0] user = users[0]
user.access_token = get_uuid() user.access_token = get_uuid()
if user and hasattr(user, 'is_active') and user.is_active == "0":
return redirect("/?error=user_inactive")
login_user(user) login_user(user)
user.save() user.save()
return redirect(f"/?auth={user.get_id()}") return redirect(f"/?auth={user.get_id()}")
@ -317,6 +327,8 @@ def github_callback():
# User has already registered, try to log in # User has already registered, try to log in
user = users[0] user = users[0]
user.access_token = get_uuid() user.access_token = get_uuid()
if user and hasattr(user, 'is_active') and user.is_active == "0":
return redirect("/?error=user_inactive")
login_user(user) login_user(user)
user.save() user.save()
return redirect("/?auth=%s" % user.get_id()) return redirect("/?auth=%s" % user.get_id())
@ -418,6 +430,8 @@ def feishu_callback():
# User has already registered, try to log in # User has already registered, try to log in
user = users[0] user = users[0]
if user and hasattr(user, 'is_active') and user.is_active == "0":
return redirect("/?error=user_inactive")
user.access_token = get_uuid() user.access_token = get_uuid()
login_user(user) login_user(user)
user.save() user.save()

2
api/common/README.md Normal file
View File

@ -0,0 +1,2 @@
The python files in this directory are shared between service. They contain common utilities, models, and functions that can be used across various
services to ensure consistency and reduce code duplication.

21
api/common/base64.py Normal file
View File

@ -0,0 +1,21 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
return base64_encoded.decode('utf-8')

View File

@ -23,6 +23,11 @@ class StatusEnum(Enum):
INVALID = "0" INVALID = "0"
class ActiveEnum(Enum):
ACTIVE = "1"
INACTIVE = "0"
class UserTenantRole(StrEnum): class UserTenantRole(StrEnum):
OWNER = 'owner' OWNER = 'owner'
ADMIN = 'admin' ADMIN = 'admin'
@ -111,7 +116,7 @@ class CanvasCategory(StrEnum):
Agent = "agent_canvas" Agent = "agent_canvas"
DataFlow = "dataflow_canvas" DataFlow = "dataflow_canvas"
VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow} VALID_CANVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
class MCPServerType(StrEnum): class MCPServerType(StrEnum):

View File

@ -26,12 +26,14 @@ from functools import wraps
from flask_login import UserMixin from flask_login import UserMixin
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api import settings, utils from api import settings, utils
from api.db import ParserType, SerializedType from api.db import ParserType, SerializedType
from api.utils.json import json_dumps, json_loads
from api.utils.configs import deserialize_b64, serialize_b64
def singleton(cls, *args, **kw): def singleton(cls, *args, **kw):
@ -70,12 +72,12 @@ class JSONField(LongTextField):
def db_value(self, value): def db_value(self, value):
if value is None: if value is None:
value = self.default_value value = self.default_value
return utils.json_dumps(value) return json_dumps(value)
def python_value(self, value): def python_value(self, value):
if not value: if not value:
return self.default_value return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField): class ListField(JSONField):
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
def db_value(self, value): def db_value(self, value):
if self._serialized_type == SerializedType.PICKLE: if self._serialized_type == SerializedType.PICKLE:
return utils.serialize_b64(value, to_str=True) return serialize_b64(value, to_str=True)
elif self._serialized_type == SerializedType.JSON: elif self._serialized_type == SerializedType.JSON:
if value is None: if value is None:
return None return None
return utils.json_dumps(value, with_type=True) return json_dumps(value, with_type=True)
else: else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported") raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value): def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE: if self._serialized_type == SerializedType.PICKLE:
return utils.deserialize_b64(value) return deserialize_b64(value)
elif self._serialized_type == SerializedType.JSON: elif self._serialized_type == SerializedType.JSON:
if value is None: if value is None:
return {} return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else: else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported") raise ValueError(f"the serialized type {self._serialized_type} is not supported")
@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def execute_sql(self, sql, params=None, commit=True): def execute_sql(self, sql, params=None, commit=True):
from peewee import OperationalError
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
try: try:
return super().execute_sql(sql, params, commit) return super().execute_sql(sql, params, commit)
except OperationalError as e: except (OperationalError, InterfaceError) as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries: error_codes = [2013, 2006]
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}") error_messages = ['', 'Lost connection']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
(str(e) in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
)
self._handle_connection_loss() self._handle_connection_loss()
time.sleep(self.retry_delay * (2**attempt)) time.sleep(self.retry_delay * (2 ** attempt))
else: else:
logging.error(f"DB execution failure: {e}") logging.error(f"DB execution failure: {e}")
raise raise
return None return None
def _handle_connection_loss(self): def _handle_connection_loss(self):
self.close_all() # self.close_all()
self.connect() # self.connect()
try:
self.close()
except Exception:
pass
try:
self.connect()
except Exception as e:
logging.error(f"Failed to reconnect: {e}")
time.sleep(0.1)
self.connect()
def begin(self): def begin(self):
from peewee import OperationalError
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
try: try:
return super().begin() return super().begin()
except OperationalError as e: except (OperationalError, InterfaceError) as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries: error_codes = [2013, 2006]
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})") error_messages = ['', 'Lost connection']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
(str(e) in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
)
self._handle_connection_loss() self._handle_connection_loss()
time.sleep(self.retry_delay * (2**attempt)) time.sleep(self.retry_delay * (2 ** attempt))
else: else:
raise raise
@ -299,7 +328,16 @@ class BaseDataBase:
def __init__(self): def __init__(self):
database_config = settings.DATABASE.copy() database_config = settings.DATABASE.copy()
db_name = database_config.pop("name") db_name = database_config.pop("name")
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
pool_config = {
'max_retries': 5,
'retry_delay': 1,
}
database_config.update(pool_config)
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
db_name, **database_config
)
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info("init database on cluster mode successfully") logging.info("init database on cluster mode successfully")

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
# #
import logging import logging
import base64
import json import json
import os import os
import time import time
@ -32,11 +31,7 @@ from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_l
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api import settings from api import settings
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from api.common.base64 import encode_to_base64
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
return base64_encoded.decode('utf-8')
def init_superuser(): def init_superuser():
@ -144,8 +139,9 @@ def init_llm_factory():
except Exception: except Exception:
pass pass
break break
doc_count = DocumentService.get_all_kb_doc_count()
for kb_id in KnowledgebaseService.get_all_ids(): for kb_id in KnowledgebaseService.get_all_ids():
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=DocumentService.get_kb_doc_count(kb_id)) KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=doc_count.get(kb_id, 0))

View File

View File

@ -0,0 +1,327 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import uuid
from api import settings
from api.utils.api_utils import group_by
from api.db import FileType, UserTenantRole, ActiveEnum
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.file_service import FileService
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.search_service import SearchService
from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search
def create_new_user(user_info: dict) -> dict:
"""
Add a new user, and create tenant, tenant llm, file folder for new user.
:param user_info: {
"email": <example@example.com>,
"nickname": <str, "name">,
"password": <decrypted password>,
"login_channel": <enum, "password">,
"is_superuser": <bool, role == "admin">,
}
:return: {
"success": <bool>,
"user_info": <dict>, # if true, return user_info
}
"""
# generate user_id and access_token for user
user_id = uuid.uuid1().hex
user_info['id'] = user_id
user_info['access_token'] = uuid.uuid1().hex
# construct tenant info
tenant = {
"id": user_id,
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
"rerank_id": settings.RERANK_MDL,
}
usr_tenant = {
"tenant_id": user_id,
"user_id": user_id,
"invited_by": user_id,
"role": UserTenantRole.OWNER,
}
# construct file folder info
file_id = uuid.uuid1().hex
file = {
"id": file_id,
"parent_id": file_id,
"tenant_id": user_id,
"created_by": user_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
try:
tenant_llm = get_init_tenant_llm(user_id)
if not UserService.save(**user_info):
return {"success": False}
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
FileService.insert(file)
return {
"success": True,
"user_info": user_info,
}
except Exception as create_error:
logging.exception(create_error)
# rollback
try:
TenantService.delete_by_id(user_id)
except Exception as e:
logging.exception(e)
try:
u = UserTenantService.query(tenant_id=user_id)
if u:
UserTenantService.delete_by_id(u[0].id)
except Exception as e:
logging.exception(e)
try:
TenantLLMService.delete_by_tenant_id(user_id)
except Exception as e:
logging.exception(e)
try:
FileService.delete_by_id(file["id"])
except Exception as e:
logging.exception(e)
# delete user row finally
try:
UserService.delete_by_id(user_id)
except Exception as e:
logging.exception(e)
# reraise
raise create_error
def delete_user_data(user_id: str) -> dict:
# use user_id to delete
usr = UserService.filter_by_id(user_id)
if not usr:
return {"success": False, "message": f"{user_id} can't be found."}
# check is inactive and not admin
if usr.is_active == ActiveEnum.ACTIVE.value:
return {"success": False, "message": f"{user_id} is active and can't be deleted."}
if usr.is_superuser:
return {"success": False, "message": "Can't delete the super user."}
# tenant info
tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]
done_msg = ''
try:
# step1. delete owned tenant info
if owned_tenant:
done_msg += "Start to delete owned tenant.\n"
tenant_id = owned_tenant[0]["tenant_id"]
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
# step1.1 delete knowledgebase related file and info
if kb_ids:
# step1.1.1 delete files in storage, remove bucket
for kb_id in kb_ids:
if STORAGE_IMPL.bucket_exists(kb_id):
STORAGE_IMPL.remove_bucket(kb_id)
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
# step1.1.2 delete file and document info in db
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
if doc_ids:
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
done_msg += f"- Deleted {doc_delete_res} document records.\n"
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
done_msg += f"- Deleted {task_delete_res} task records.\n"
file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
if file_ids:
file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
done_msg += f"- Deleted {file_delete_res} file records.\n"
if doc_ids or file_ids:
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
[i["id"] for i in doc_ids],
[f["id"] for f in file_ids]
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step1.1.3 delete chunk in es
r = settings.docStoreConn.delete({"kb_id": kb_ids},
search.index_name(tenant_id), kb_ids)
done_msg += f"- Deleted {r} chunk records.\n"
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
# step1.1.4 delete agents
agent_delete_res = delete_user_agents(usr.id)
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
# step1.1.5 delete dialogs
dialog_delete_res = delete_user_dialogs(usr.id)
done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
# step1.1.6 delete mcp server
mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
# step1.1.7 delete search
search_delete_res = SearchService.delete_by_tenant_id(usr.id)
done_msg += f"- Deleted {search_delete_res} search records.\n"
# step1.2 delete tenant_llm and tenant_langfuse
llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
# step1.3 delete own tenant
tenant_delete_res = TenantService.delete_by_id(tenant_id)
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
# step2 delete user-tenant relation
if tenants:
# step2.1 delete docs and files in joined team
joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
if joined_tenants:
done_msg += "Start to delete data in joined tenants.\n"
created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
if created_documents:
# step2.1.1 delete files
doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
if created_files:
# step2.1.1.1 delete file in storage
for f in created_files:
STORAGE_IMPL.rm(f.parent_id, f.location)
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
# step2.1.1.2 delete file record
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
done_msg += f"- Deleted {file_delete_res} file records.\n"
# step2.1.2 delete document-file relation record
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
[d['id'] for d in created_documents],
[f.id for f in created_files]
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step2.1.3 delete chunks
doc_groups = group_by(created_documents, "tenant_id")
kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
# chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
chunk_delete_res = 0
kb_doc_info = {}
for _tenant_id, kb_doc in kb_grouped_doc.items():
for _kb_id, docs in kb_doc.items():
chunk_delete_res += settings.docStoreConn.delete(
{"doc_id": [d["id"] for d in docs]},
search.index_name(_tenant_id), _kb_id
)
# record doc info
if _kb_id in kb_doc_info.keys():
kb_doc_info[_kb_id]['doc_num'] += 1
kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
else:
kb_doc_info[_kb_id] = {
'doc_num': 1,
'token_num': sum([d["token_num"] for d in docs]),
'chunk_num': sum([d["chunk_num"] for d in docs])
}
done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
# step2.1.4 delete tasks
task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
done_msg += f"- Deleted {task_delete_res} tasks.\n"
# step2.1.5 delete document record
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
done_msg += f"- Deleted {doc_delete_res} documents.\n"
# step2.1.6 update knowledge base doc&chunk&token cnt
for kb_id, doc_num in kb_doc_info.items():
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
# step2.2 delete relation
user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
# step3 finally delete user
user_delete_res = UserService.delete_by_id(usr.id)
done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"
return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}
except Exception as e:
logging.exception(e)
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
def delete_user_agents(user_id: str) -> dict:
"""
use user_id to delete
:return: {
"agents_deleted_count": 1,
"version_deleted_count": 2
}
"""
agents_deleted_count, agents_version_deleted_count = 0, 0
user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
if user_agents:
agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
return {
"agents_deleted_count": agents_deleted_count,
"version_deleted_count": agents_version_deleted_count
}
def delete_user_dialogs(user_id: str) -> dict:
"""
use user_id to delete
:return: {
"dialogs_deleted_count": 1,
"conversations_deleted_count": 1,
"api_token_deleted_count": 2,
"api4conversation_deleted_count": 2
}
"""
dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
if user_dialogs:
# delete conversation
conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
# delete api token
api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
# delete api for conversation
api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
# delete dialog at last
dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
return {
"dialogs_deleted_count": dialog_deleted_count,
"conversations_deleted_count": conversations_deleted_count,
"api_token_deleted_count": api_token_deleted_count,
"api4conversation_deleted_count": api4conversation_deleted_count
}

View File

@ -19,7 +19,7 @@ from pathlib import PurePath
from .user_service import UserService as UserService from .user_service import UserService as UserService
def split_name_counter(filename: str) -> tuple[str, int | None]: def _split_name_counter(filename: str) -> tuple[str, int | None]:
""" """
Splits a filename into main part and counter (if present in parentheses). Splits a filename into main part and counter (if present in parentheses).
@ -87,7 +87,7 @@ def duplicate_name(query_func, **kwargs) -> str:
stem = path.stem stem = path.stem
suffix = path.suffix suffix = path.suffix
main_part, counter = split_name_counter(stem) main_part, counter = _split_name_counter(stem)
counter = counter + 1 if counter else 1 counter = counter + 1 if counter else 1
new_name = f"{main_part}({counter}){suffix}" new_name = f"{main_part}({counter}){suffix}"

View File

@ -35,6 +35,11 @@ class APITokenService(CommonService):
cls.model.token == token cls.model.token == token
) )
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
class API4ConversationService(CommonService): class API4ConversationService(CommonService):
model = API4Conversation model = API4Conversation
@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
cls.model.create_date <= to_date, cls.model.create_date <= to_date,
cls.model.source == source cls.model.source == source
).group_by(cls.model.create_date.truncate("day")).dicts() ).group_by(cls.model.create_date.truncate("day")).dicts()
@classmethod
@DB.connection_context()
def delete_by_dialog_ids(cls, dialog_ids):
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()

View File

@ -18,7 +18,7 @@ import logging
import time import time
from uuid import uuid4 from uuid import uuid4
from agent.canvas import Canvas from agent.canvas import Canvas
from api.db import CanvasCategory from api.db import CanvasCategory, TenantPermission
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
from api.db.services.api_service import API4ConversationService from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
@ -63,7 +63,38 @@ class UserCanvasService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_tenant_id(cls, pid): def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted agents, be cautious
fields = [
cls.model.id,
cls.model.title,
cls.model.permission,
cls.model.canvas_type,
cls.model.canvas_category
]
# find team agents and owned agents
agents = cls.model.select(*fields).where(
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
cls.model.user_id == user_id
)
)
# sort by create_time, asc
agents.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 50
res = []
while True:
ag_batch = agents.offset(offset).limit(limit)
_temp = list(ag_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_by_canvas_id(cls, pid):
try: try:
fields = [ fields = [
@ -138,7 +169,7 @@ class UserCanvasService(CommonService):
@DB.connection_context() @DB.connection_context()
def accessible(cls, canvas_id, tenant_id): def accessible(cls, canvas_id, tenant_id):
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
e, c = UserCanvasService.get_by_tenant_id(canvas_id) e, c = UserCanvasService.get_by_canvas_id(canvas_id)
if not e: if not e:
return False return False

View File

@ -14,12 +14,24 @@
# limitations under the License. # limitations under the License.
# #
from datetime import datetime from datetime import datetime
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import peewee import peewee
from peewee import InterfaceError, OperationalError
from api.db.db_models import DB from api.db.db_models import DB
from api.utils import current_timestamp, datetime_format, get_uuid from api.utils import current_timestamp, datetime_format, get_uuid
def retry_db_operation(func):
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=5),
retry=retry_if_exception_type((InterfaceError, OperationalError)),
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
reraise=True,
)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
class CommonService: class CommonService:
"""Base service class that provides common database operations. """Base service class that provides common database operations.
@ -202,6 +214,7 @@ class CommonService:
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
@retry_db_operation
def update_by_id(cls, pid, data): def update_by_id(cls, pid, data):
# Update a single record by ID # Update a single record by ID
# Args: # Args:

View File

@ -23,7 +23,7 @@ from api.db.services.dialog_service import DialogService, chat
from api.utils import get_uuid from api.utils import get_uuid
import json import json
from rag.prompts import chunks_format from rag.prompts.generator import chunks_format
class ConversationService(CommonService): class ConversationService(CommonService):
@ -48,6 +48,21 @@ class ConversationService(CommonService):
return list(sessions.dicts()) return list(sessions.dicts())
@classmethod
@DB.connection_context()
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
sessions.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
s_batch = sessions.offset(offset).limit(limit)
_temp = list(s_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
def structure_answer(conv, ans, message_id, session_id): def structure_answer(conv, ans, message_id, session_id):
reference = ans["reference"] reference = ans["reference"]

View File

@ -39,8 +39,8 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
from rag.utils import num_tokens_from_string, rmSpace from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily from rag.utils.tavily_conn import Tavily
@ -159,6 +159,22 @@ class DialogService(CommonService):
return list(dialogs.dicts()), count return list(dialogs.dicts()), count
@classmethod
@DB.connection_context()
def get_all_dialogs_by_tenant_id(cls, tenant_id):
fields = [cls.model.id]
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
dialogs.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
d_batch = dialogs.offset(offset).limit(limit)
_temp = list(d_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
def chat_solo(dialog, messages, stream=True): def chat_solo(dialog, messages, stream=True):
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
@ -176,7 +192,7 @@ def chat_solo(dialog, messages, stream=True):
delta_ans = "" delta_ans = ""
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans answer = ans
delta_ans = ans[len(last_ans) :] delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16: if num_tokens_from_string(delta_ans) < 16:
continue continue
last_ans = answer last_ans = answer
@ -261,13 +277,13 @@ def convert_conditions(metadata_condition):
"not is": "" "not is": ""
} }
return [ return [
{ {
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
"key": cond["name"], "key": cond["name"],
"value": cond["value"] "value": cond["value"]
} }
for cond in metadata_condition.get("conditions", []) for cond in metadata_condition.get("conditions", [])
] ]
def meta_filter(metas: dict, filters: list[dict]): def meta_filter(metas: dict, filters: list[dict]):
@ -284,19 +300,19 @@ def meta_filter(metas: dict, filters: list[dict]):
value = str(value) value = str(value)
for conds in [ for conds in [
(operator == "contains", str(value).lower() in str(input).lower()), (operator == "contains", str(value).lower() in str(input).lower()),
(operator == "not contains", str(value).lower() not in str(input).lower()), (operator == "not contains", str(value).lower() not in str(input).lower()),
(operator == "start with", str(input).lower().startswith(str(value).lower())), (operator == "start with", str(input).lower().startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower())), (operator == "end with", str(input).lower().endswith(str(value).lower())),
(operator == "empty", not input), (operator == "empty", not input),
(operator == "not empty", input), (operator == "not empty", input),
(operator == "=", input == value), (operator == "=", input == value),
(operator == "", input != value), (operator == "", input != value),
(operator == ">", input > value), (operator == ">", input > value),
(operator == "<", input < value), (operator == "<", input < value),
(operator == "", input >= value), (operator == "", input >= value),
(operator == "", input <= value), (operator == "", input <= value),
]: ]:
try: try:
if all(conds): if all(conds):
ids.extend(docids) ids.extend(docids)
@ -456,7 +472,8 @@ def chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"]) kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"): if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT)) ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]: if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck) kbinfos["chunks"].insert(0, ck)
@ -467,7 +484,8 @@ def chat(dialog, messages, stream=True, **kwargs):
retrieval_ts = timer() retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"): if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"] empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)} yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos} return {"answer": prompt_config["empty_response"], "reference": kbinfos}
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@ -565,7 +583,8 @@ def chat(dialog, messages, stream=True, **kwargs):
if langfuse_tracer: if langfuse_tracer:
langfuse_generation = langfuse_tracer.start_generation( langfuse_generation = langfuse_tracer.start_generation(
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
) )
if stream: if stream:
@ -575,12 +594,12 @@ def chat(dialog, messages, stream=True, **kwargs):
if thought: if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans answer = ans
delta_ans = ans[len(last_ans) :] delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16: if num_tokens_from_string(delta_ans) < 16:
continue continue
last_ans = answer last_ans = answer
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans) :] delta_ans = answer[len(last_ans):]
if delta_ans: if delta_ans:
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer) yield decorate_answer(thought + answer)
@ -676,7 +695,9 @@ Please write the SQL, only SQL, without any other explanations or text.
# compose Markdown table # compose Markdown table
columns = ( columns = (
"|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") "|" + "|".join(
[re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
"|Source|" if docid_idx and docid_idx else "|")
) )
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@ -753,7 +774,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = None doc_ids = None
kbinfos = retriever.retrieval( kbinfos = retriever.retrieval(
question = question, question=question,
embd_mdl=embd_mdl, embd_mdl=embd_mdl,
tenant_ids=tenant_ids, tenant_ids=tenant_ids,
kb_ids=kb_ids, kb_ids=kb_ids,
@ -775,7 +796,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
def decorate_answer(answer): def decorate_answer(answer):
nonlocal knowledges, kbinfos, sys_prompt nonlocal knowledges, kbinfos, sys_prompt
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl, tkweight=0.7, vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: if not recall_docs:

View File

@ -243,6 +243,46 @@ class DocumentService(CommonService):
return int(query.scalar()) or 0 return int(query.scalar()) or 0
@classmethod
@DB.connection_context()
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
fields = [cls.model.id]
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
docs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 100
res = []
while True:
doc_batch = docs.offset(offset).limit(limit)
_temp = list(doc_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_all_docs_by_creator_id(cls, creator_id):
fields = [
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
cls.model.created_by == creator_id
)
docs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 100
res = []
while True:
doc_batch = docs.offset(offset).limit(limit)
_temp = list(doc_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def insert(cls, doc): def insert(cls, doc):
@ -517,9 +557,6 @@ class DocumentService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name): def get_doc_id_by_doc_name(cls, doc_name):
"""
highly rely on the strict deduplication guarantee from Document
"""
fields = [cls.model.id] fields = [cls.model.id]
doc_id = cls.model.select(*fields) \ doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name) .where(cls.model.name == doc_name)
@ -681,8 +718,16 @@ class DocumentService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_kb_doc_count(cls, kb_id): def get_kb_doc_count(cls, kb_id):
return len(cls.model.select(cls.model.id).where( return cls.model.select().where(cls.model.kb_id == kb_id).count()
cls.model.kb_id == kb_id).dicts())
@classmethod
@DB.connection_context()
def get_all_kb_doc_count(cls):
result = {}
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
for row in rows:
result[row.kb_id] = row.count
return result
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()

View File

@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
objs = cls.model.select().where(cls.model.document_id == document_id) objs = cls.model.select().where(cls.model.document_id == document_id)
return objs return objs
@classmethod
@DB.connection_context()
def get_by_document_ids(cls, document_ids):
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
return list(objs.dicts())
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def insert(cls, obj): def insert(cls, obj):
@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
def delete_by_file_id(cls, file_id): def delete_by_file_id(cls, file_id):
return cls.model.delete().where(cls.model.file_id == file_id).execute() return cls.model.delete().where(cls.model.file_id == file_id).execute()
@classmethod
@DB.connection_context()
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
if not document_ids:
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
elif not file_ids:
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def delete_by_document_id(cls, doc_id): def delete_by_document_id(cls, doc_id):

View File

@ -161,6 +161,23 @@ class FileService(CommonService):
result_ids.append(folder_id) result_ids.append(folder_id)
return result_ids return result_ids
@classmethod
@DB.connection_context()
def get_all_file_ids_by_tenant_id(cls, tenant_id):
fields = [cls.model.id]
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
files.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
file_batch = files.offset(offset).limit(limit)
_temp = list(file_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def create_folder(cls, file, parent_id, name, count): def create_folder(cls, file, parent_id, name, count):

View File

@ -18,7 +18,7 @@ from datetime import datetime
from peewee import fn, JOIN from peewee import fn, JOIN
from api.db import StatusEnum, TenantPermission from api.db import StatusEnum, TenantPermission
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant, UserCanvas from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format from api.utils import current_timestamp, datetime_format
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
return list(kbs.dicts()), count return list(kbs.dicts()), count
@classmethod
@DB.connection_context()
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted kb, be cautious.
fields = [
cls.model.name,
cls.model.language,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.status,
cls.model.create_date,
cls.model.update_date
]
# find team kb and owned kb
kbs = cls.model.select(*fields).where(
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id
)
)
# sort by create_time asc
kbs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later.
offset, limit = 0, 50
res = []
while True:
kb_batch = kbs.offset(offset).limit(limit)
_temp = list(kb_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_kb_ids(cls, tenant_id): def get_kb_ids(cls, tenant_id):
@ -226,7 +261,7 @@ class KnowledgebaseService(CommonService):
cls.model.chunk_num, cls.model.chunk_num,
cls.model.parser_id, cls.model.parser_id,
cls.model.pipeline_id, cls.model.pipeline_id,
UserCanvas.title, UserCanvas.title.alias("pipeline_name"),
UserCanvas.avatar.alias("pipeline_avatar"), UserCanvas.avatar.alias("pipeline_avatar"),
cls.model.parser_config, cls.model.parser_config,
cls.model.pagerank, cls.model.pagerank,
@ -240,16 +275,14 @@ class KnowledgebaseService(CommonService):
cls.model.update_time cls.model.update_time
] ]
kbs = cls.model.select(*fields)\ kbs = cls.model.select(*fields)\
.join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\ .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.where( .where(
(cls.model.id == kb_id), (cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value) (cls.model.status == StatusEnum.VALID.value)
) ).dicts()
if not kbs: if not kbs:
return return
d = kbs[0].to_dict() return kbs[0]
return d
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
@ -447,3 +480,17 @@ class KnowledgebaseService(CommonService):
else: else:
raise e raise e
@classmethod
@DB.connection_context()
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
kb_row = cls.model.get_by_id(kb_id)
if not kb_row:
raise RuntimeError(f"kb_id {kb_id} does not exist")
update_dict = {
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
'token_num': kb_row.token_num - doc_num_info['token_num'],
'update_time': current_timestamp(),
'update_date': datetime_format(datetime.now())
}
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()

View File

@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
except peewee.DoesNotExist: except peewee.DoesNotExist:
return None return None
@classmethod
@DB.connection_context()
def delete_ty_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
@classmethod @classmethod
def update_by_tenant(cls, tenant_id, langfuse_keys): def update_by_tenant(cls, tenant_id, langfuse_keys):
langfuse_keys["update_time"] = current_timestamp() langfuse_keys["update_time"] = current_timestamp()

View File

@ -84,3 +84,8 @@ class MCPServerService(CommonService):
return bool(mcp_server), mcp_server return bool(mcp_server), mcp_server
except Exception: except Exception:
return False, None return False, None
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id: str):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()

View File

@ -110,3 +110,8 @@ class SearchService(CommonService):
query = query.paginate(page_number, items_per_page) query = query.paginate(page_number, items_per_page)
return list(query.dicts()), count return list(query.dicts()), count
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()

View File

@ -316,6 +316,12 @@ class TaskService(CommonService):
process_duration = (datetime.now() - task.begin_at).total_seconds() process_duration = (datetime.now() - task.begin_at).total_seconds()
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute() cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
@classmethod
@DB.connection_context()
def delete_by_doc_ids(cls, doc_ids):
"""Delete task associated with a document."""
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
def queue_tasks(doc: dict, bucket: str, name: str, priority: int): def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
"""Create and queue document processing tasks. """Create and queue document processing tasks.

View File

@ -209,6 +209,11 @@ class TenantLLMService(CommonService):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs) return list(objs)
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
@staticmethod @staticmethod
def llm_id2llm_type(llm_id: str) -> str | None: def llm_id2llm_type(llm_id: str) -> str | None:
from api.db.services.llm_service import LLMService from api.db.services.llm_service import LLMService

View File

@ -24,7 +24,24 @@ class UserCanvasVersionService(CommonService):
return None return None
except Exception: except Exception:
return None return None
@classmethod
@DB.connection_context()
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
fields = [cls.model.id]
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
versions.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
version_batch = versions.offset(offset).limit(limit)
_temp = list(version_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def delete_all_versions(cls, user_canvas_id): def delete_all_versions(cls, user_canvas_id):

View File

@ -45,22 +45,22 @@ class UserService(CommonService):
def query(cls, cols=None, reverse=None, order_by=None, **kwargs): def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
if 'access_token' in kwargs: if 'access_token' in kwargs:
access_token = kwargs['access_token'] access_token = kwargs['access_token']
# Reject empty, None, or whitespace-only access tokens # Reject empty, None, or whitespace-only access tokens
if not access_token or not str(access_token).strip(): if not access_token or not str(access_token).strip():
logging.warning("UserService.query: Rejecting empty access_token query") logging.warning("UserService.query: Rejecting empty access_token query")
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
# Reject tokens that are too short (should be UUID, 32+ chars) # Reject tokens that are too short (should be UUID, 32+ chars)
if len(str(access_token).strip()) < 32: if len(str(access_token).strip()) < 32:
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars") logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
# Reject tokens that start with "INVALID_" (from logout) # Reject tokens that start with "INVALID_" (from logout)
if str(access_token).startswith("INVALID_"): if str(access_token).startswith("INVALID_"):
logging.warning("UserService.query: Rejecting invalidated access_token") logging.warning("UserService.query: Rejecting invalidated access_token")
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
# Call parent query method for valid requests # Call parent query method for valid requests
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
@ -100,6 +100,12 @@ class UserService(CommonService):
else: else:
return None return None
@classmethod
@DB.connection_context()
def query_user_by_email(cls, email):
users = cls.model.select().where((cls.model.email == email))
return list(users)
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def save(cls, **kwargs): def save(cls, **kwargs):
@ -133,6 +139,17 @@ class UserService(CommonService):
cls.model.update(user_dict).where( cls.model.update(user_dict).where(
cls.model.id == user_id).execute() cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def update_user_password(cls, user_id, new_password):
with DB.atomic():
update_dict = {
"password": generate_password_hash(str(new_password)),
"update_time": current_timestamp(),
"update_date": datetime_format(datetime.now())
}
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def is_admin(cls, user_id): def is_admin(cls, user_id):
@ -140,6 +157,12 @@ class UserService(CommonService):
cls.model.id == user_id, cls.model.id == user_id,
cls.model.is_superuser == 1).count() > 0 cls.model.is_superuser == 1).count() > 0
@classmethod
@DB.connection_context()
def get_all_users(cls):
users = cls.model.select()
return list(users)
class TenantService(CommonService): class TenantService(CommonService):
"""Service class for managing tenant-related database operations. """Service class for managing tenant-related database operations.
@ -265,6 +288,17 @@ class UserTenantService(CommonService):
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value))) .join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
.where(cls.model.status == StatusEnum.VALID.value).dicts()) .where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_user_tenant_relation_by_user_id(cls, user_id):
fields = [
cls.model.id,
cls.model.user_id,
cls.model.tenant_id,
cls.model.role
]
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_num_members(cls, user_id: str): def get_num_members(cls, user_id: str):

View File

@ -41,7 +41,7 @@ from api import utils
from api.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data from api.db.init_data import init_web_data
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from api.utils import show_configs from api.utils.configs import show_configs
from rag.settings import print_rag_settings from rag.settings import print_rag_settings
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock from rag.utils.redis_conn import RedisDistributedLock

View File

@ -24,7 +24,7 @@ import rag.utils.es_conn
import rag.utils.infinity_conn import rag.utils.infinity_conn
import rag.utils.opensearch_conn import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils import decrypt_database_config, get_base_config from api.utils.configs import decrypt_database_config, get_base_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from rag.nlp import search from rag.nlp import search

View File

@ -16,184 +16,15 @@
import base64 import base64
import datetime import datetime
import hashlib import hashlib
import io
import json
import os import os
import pickle
import socket import socket
import time import time
import uuid import uuid
import requests import requests
import logging
import copy
from enum import Enum, IntEnum
import importlib import importlib
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from filelock import FileLock
from api.constants import SERVICE_CONF
from . import file_utils from .common import string_to_bytes
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def rag_uuid():
return uuid.uuid1().hex
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)
def current_timestamp(): def current_timestamp():
@ -215,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
return time_stamp return time_stamp
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def get_lan_ip(): def get_lan_ip():
if os.name != "nt": if os.name != "nt":
import fcntl import fcntl
@ -298,47 +90,6 @@ def from_dict_hook(in_dict: dict):
return in_dict return in_dict
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
def get_uuid(): def get_uuid():
return uuid.uuid1().hex return uuid.uuid1().hex
@ -363,37 +114,6 @@ def elapsed2time(elapsed):
return '%02d:%02d:%02d' % (hour, minuter, second) return '%02d:%02d:%02d' % (hour, minuter, second)
def decrypt(line):
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"private.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return cipher.decrypt(base64.b64decode(
line), "Fail to decrypt password!").decode('utf-8')
def decrypt2(crypt_text):
from base64 import b64decode, b16decode
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
from Crypto.PublicKey import RSA
decode_data = b64decode(crypt_text)
if len(decode_data) == 127:
hex_fixed = '00' + decode_data.hex()
decode_data = b16decode(hex_fixed.upper())
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"private.pem")
pem = open(file_path).read()
rsa_key = RSA.importKey(pem, "Welcome")
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
decrypt_text = cipher.decrypt(decode_data, None)
return (b64decode(decrypt_text)).decode()
def download_img(url): def download_img(url):
if not url: if not url:
return "" return ""
@ -408,5 +128,5 @@ def delta_seconds(date_string: str):
return (datetime.datetime.now() - dt).total_seconds() return (datetime.datetime.now() - dt).total_seconds()
def hash_str2int(line:str, mod: int=10 ** 8) -> int: def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod

View File

@ -39,6 +39,7 @@ from flask import (
make_response, make_response,
send_file, send_file,
) )
from flask_login import current_user
from flask import ( from flask import (
request as flask_request, request as flask_request,
) )
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
from api import settings from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db import ActiveEnum
from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.db.services import UserService
from api.db.services.llm_service import LLMService from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps from api.utils.json import CustomJSONEncoder, json_dumps
from api.utils import get_uuid
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
return decorator return decorator
def active_required(f):
@wraps(f)
def wrapper(*args, **kwargs):
user_id = current_user.id
usr = UserService.filter_by_id(user_id)
# check is_active
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
return f(*args, **kwargs)
return wrapper
def is_localhost(ip): def is_localhost(ip):
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"} return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
@ -643,6 +659,16 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
return transformed_data return transformed_data
def group_by(list_of_dict, key):
res = {}
for item in list_of_dict:
if item[key] in res.keys():
res[item[key]].append(item)
else:
res[item[key]] = [item]
return res
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]: def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
results = {} results = {}
tool_call_sessions = [] tool_call_sessions = []

23
api/utils/common.py Normal file
View File

@ -0,0 +1,23 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")

179
api/utils/configs.py Normal file
View File

@ -0,0 +1,179 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import io
import copy
import logging
import base64
import pickle
import importlib
from api.utils import file_utils
from filelock import FileLock
from api.utils.common import bytes_to_string, string_to_bytes
from api.constants import SERVICE_CONF
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)

64
api/utils/crypt.py Normal file
View File

@ -0,0 +1,64 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import os
import sys
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from api.utils import file_utils
def crypt(line):
"""
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
"""
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
encrypted_password = cipher.encrypt(password_base64.encode())
return base64.b64encode(encrypted_password).decode('utf-8')
def decrypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
def decrypt2(crypt_text):
from base64 import b64decode, b16decode
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
from Crypto.PublicKey import RSA
decode_data = b64decode(crypt_text)
if len(decode_data) == 127:
hex_fixed = '00' + decode_data.hex()
decode_data = b16decode(hex_fixed.upper())
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
pem = open(file_path).read()
rsa_key = RSA.importKey(pem, "Welcome")
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
decrypt_text = cipher.decrypt(decode_data, None)
return (b64decode(decrypt_text)).decode()
if __name__ == "__main__":
passwd = crypt(sys.argv[1])
print(passwd)
print(decrypt(passwd))

107
api/utils/health_utils.py Normal file
View File

@ -0,0 +1,107 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from timeit import default_timer as timer
from api import settings
from api.db.db_models import DB
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
def _ok_nok(ok: bool) -> str:
return "ok" if ok else "nok"
def check_db() -> tuple[bool, dict]:
st = timer()
try:
# lightweight probe; works for MySQL/Postgres
DB.execute_sql("SELECT 1")
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_redis() -> tuple[bool, dict]:
st = timer()
try:
ok = bool(REDIS_CONN.health())
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_doc_engine() -> tuple[bool, dict]:
st = timer()
try:
meta = settings.docStoreConn.health()
# treat any successful call as ok
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_storage() -> tuple[bool, dict]:
st = timer()
try:
STORAGE_IMPL.health()
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def run_health_checks() -> tuple[dict, bool]:
result: dict[str, str | dict] = {}
db_ok, db_meta = check_db()
result["db"] = _ok_nok(db_ok)
if not db_ok:
result.setdefault("_meta", {})["db"] = db_meta
try:
redis_ok, redis_meta = check_redis()
result["redis"] = _ok_nok(redis_ok)
if not redis_ok:
result.setdefault("_meta", {})["redis"] = redis_meta
except Exception:
result["redis"] = "nok"
try:
doc_ok, doc_meta = check_doc_engine()
result["doc_engine"] = _ok_nok(doc_ok)
if not doc_ok:
result.setdefault("_meta", {})["doc_engine"] = doc_meta
except Exception:
result["doc_engine"] = "nok"
try:
sto_ok, sto_meta = check_storage()
result["storage"] = _ok_nok(sto_ok)
if not sto_ok:
result.setdefault("_meta", {})["storage"] = sto_meta
except Exception:
result["storage"] = "nok"
all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok")
result["status"] = "ok" if all_ok else "nok"
return result, all_ok

78
api/utils/json.py Normal file
View File

@ -0,0 +1,78 @@
import datetime
import json
from enum import Enum, IntEnum
from api.utils.common import string_to_bytes, bytes_to_string
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)

View File

@ -1,40 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import os
import sys
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from api.utils import decrypt, file_utils
def crypt(line):
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"public.pem")
rsa_key = RSA.importKey(open(file_path).read(),"Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
encrypted_password = cipher.encrypt(password_base64.encode())
return base64.b64encode(encrypted_password).decode('utf-8')
if __name__ == "__main__":
passwd = crypt(sys.argv[1])
print(passwd)
print(decrypt(passwd))

19
chat_demo/index.html Normal file
View File

@ -0,0 +1,19 @@
<iframe src="http://localhost:9222/next-chats/widget?shared_id=9dcfc68696c611f0bb789b9b8b765d12&from=chat&auth=U4MDU3NzkwOTZjNzExZjBiYjc4OWI5Yj&mode=master&streaming=false"
style="position:fixed;bottom:0;right:0;width:100px;height:100px;border:none;background:transparent;z-index:9999"
frameborder="0" allow="microphone;camera"></iframe>
<script>
window.addEventListener('message',e=>{
if(e.origin!=='http://localhost:9222')return;
if(e.data.type==='CREATE_CHAT_WINDOW'){
if(document.getElementById('chat-win'))return;
const i=document.createElement('iframe');
i.id='chat-win';i.src=e.data.src;
i.style.cssText='position:fixed;bottom:104px;right:24px;width:380px;height:500px;border:none;background:transparent;z-index:9998;display:none';
i.frameBorder='0';i.allow='microphone;camera';
document.body.appendChild(i);
}else if(e.data.type==='TOGGLE_CHAT'){
const w=document.getElementById('chat-win');
if(w)w.style.display=e.data.isOpen?'block':'none';
}else if(e.data.type==='SCROLL_PASSTHROUGH')window.scrollBy(0,e.data.deltaY);
});
</script>

154
chat_demo/widget_demo.html Normal file
View File

@ -0,0 +1,154 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Floating Chat Widget Demo</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 40px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
color: white;
}
.demo-content {
max-width: 800px;
margin: 0 auto;
}
.demo-content h1 {
text-align: center;
font-size: 2.5rem;
margin-bottom: 2rem;
}
.demo-content p {
font-size: 1.2rem;
line-height: 1.6;
margin-bottom: 1.5rem;
}
.feature-list {
background: rgba(255, 255, 255, 0.1);
border-radius: 10px;
padding: 2rem;
margin: 2rem 0;
}
.feature-list h3 {
margin-top: 0;
font-size: 1.5rem;
}
.feature-list ul {
list-style-type: none;
padding: 0;
}
.feature-list li {
padding: 0.5rem 0;
padding-left: 1.5rem;
position: relative;
}
.feature-list li:before {
content: "✓";
position: absolute;
left: 0;
color: #4ade80;
font-weight: bold;
}
</style>
</head>
<body>
<div class="demo-content">
<h1>🚀 Floating Chat Widget Demo</h1>
<p>
Welcome to our demo page! This page simulates a real website with content.
Look for the floating chat button in the bottom-right corner - just like Intercom!
</p>
<div class="feature-list">
<h3>🎯 Widget Features</h3>
<ul>
<li>Floating button that stays visible while scrolling</li>
<li>Click to open/close the chat window</li>
<li>Minimize button to collapse the chat</li>
<li>Professional Intercom-style design</li>
<li>Unread message indicator (red badge)</li>
<li>Transparent background integration</li>
<li>Responsive design for all screen sizes</li>
</ul>
</div>
<p>
The chat widget is completely separate from your website's content and won't
interfere with your existing layout or functionality. It's designed to be
lightweight and performant.
</p>
<p>
Try scrolling this page - notice how the chat button stays in position.
Click it to start a conversation with our AI assistant!
</p>
<div class="feature-list">
<h3>🔧 Implementation</h3>
<ul>
<li>Simple iframe embed - just copy and paste</li>
<li>No JavaScript dependencies required</li>
<li>Works on any website or platform</li>
<li>Customizable appearance and behavior</li>
<li>Secure and privacy-focused</li>
</ul>
</div>
<p>
This is just placeholder content to demonstrate how the widget integrates
seamlessly with your existing website content. The widget floats above
everything else without disrupting your user experience.
</p>
<p style="margin-top: 4rem; text-align: center; font-style: italic;">
🎉 Ready to add this to your website? Get your embed code from the admin panel!
</p>
</div>
<iframe id="main-widget" src="http://localhost:9222/next-chats/widget?shared_id=9dcfc68696c611f0bb789b9b8b765d12&from=chat&auth=U4MDU3NzkwOTZjNzExZjBiYjc4OWI5Yj&visible_avatar=1&locale=zh&mode=master&streaming=false"
style="position:fixed;bottom:0;right:0;width:100px;height:100px;border:none;background:transparent;z-index:9999;opacity:0;transition:opacity 0.2s ease"
frameborder="0" allow="microphone;camera"></iframe>
<script>
window.addEventListener('message',e=>{
if(e.origin!=='http://localhost:9222')return;
if(e.data.type==='WIDGET_READY'){
// Show the main widget when React is ready
const mainWidget = document.getElementById('main-widget');
if(mainWidget) mainWidget.style.opacity = '1';
}else if(e.data.type==='CREATE_CHAT_WINDOW'){
if(document.getElementById('chat-win'))return;
const i=document.createElement('iframe');
i.id='chat-win';i.src=e.data.src;
i.style.cssText='position:fixed;bottom:104px;right:24px;width:380px;height:500px;border:none;background:transparent;z-index:9998;display:none;opacity:0;transition:opacity 0.2s ease';
i.frameBorder='0';i.allow='microphone;camera';
document.body.appendChild(i);
}else if(e.data.type==='TOGGLE_CHAT'){
const w=document.getElementById('chat-win');
if(w){
if(e.data.isOpen){
w.style.display='block';
// Wait for the iframe content to be ready before showing
setTimeout(() => w.style.opacity='1', 100);
}else{
w.style.opacity='0';
setTimeout(() => w.style.display='none', 200);
}
}
}else if(e.data.type==='SCROLL_PASSTHROUGH')window.scrollBy(0,e.data.deltaY);
});
</script>
</body>
</html>

View File

@ -402,7 +402,7 @@
"is_tools": true "is_tools": true
}, },
{ {
"llm_name": "qwen3-max-preview", "llm_name": "qwen3-max",
"tags": "LLM,CHAT,256k", "tags": "LLM,CHAT,256k",
"max_tokens": 256000, "max_tokens": 256000,
"model_type": "chat", "model_type": "chat",
@ -436,6 +436,27 @@
"model_type": "chat", "model_type": "chat",
"is_tools": true "is_tools": true
}, },
{
"llm_name": "qwen3-vl-plus",
"tags": "LLM,CHAT,IMAGE2TEXT,256k",
"max_tokens": 256000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "qwen3-vl-235b-a22b-instruct",
"tags": "LLM,CHAT,IMAGE2TEXT,128k",
"max_tokens": 128000,
"model_type": "image2text",
"is_tools": true
},
{
"llm_name": "qwen3-vl-235b-a22b-thinking",
"tags": "LLM,CHAT,IMAGE2TEXT,128k",
"max_tokens": 128000,
"model_type": "image2text",
"is_tools": true
},
{ {
"llm_name": "qwen3-235b-a22b-instruct-2507", "llm_name": "qwen3-235b-a22b-instruct-2507",
"tags": "LLM,CHAT,128k", "tags": "LLM,CHAT,128k",
@ -457,6 +478,20 @@
"model_type": "chat", "model_type": "chat",
"is_tools": true "is_tools": true
}, },
{
"llm_name": "qwen3-next-80b-a3b-instruct",
"tags": "LLM,CHAT,128k",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{
"llm_name": "qwen3-next-80b-a3b-thinking",
"tags": "LLM,CHAT,128k",
"max_tokens": 128000,
"model_type": "chat",
"is_tools": true
},
{ {
"llm_name": "qwen3-0.6b", "llm_name": "qwen3-0.6b",
"tags": "LLM,CHAT,32k", "tags": "LLM,CHAT,32k",
@ -622,6 +657,13 @@
"tags": "SPEECH2TEXT,8k", "tags": "SPEECH2TEXT,8k",
"max_tokens": 8000, "max_tokens": 8000,
"model_type": "speech2text" "model_type": "speech2text"
},
{
"llm_name": "qianwen-deepresearch-30b-a3b-131k",
"tags": "LLM,CHAT,1M,AGENT,DEEPRESEARCH",
"max_tokens": 1000000,
"model_type": "chat",
"is_tools": true
} }
] ]
}, },

View File

@ -1,6 +1,9 @@
ragflow: ragflow:
host: 0.0.0.0 host: 0.0.0.0
http_port: 9380 http_port: 9380
admin:
host: 0.0.0.0
http_port: 9381
mysql: mysql:
name: 'rag_flow' name: 'rag_flow'
user: 'root' user: 'root'

View File

@ -19,7 +19,7 @@ from PIL import Image
from api.utils.api_utils import timeout from api.utils.api_utils import timeout
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.prompts import vision_llm_figure_describe_prompt from rag.prompts.generator import vision_llm_figure_describe_prompt
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions): def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):

View File

@ -37,7 +37,7 @@ from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
from rag.prompts import vision_llm_describe_prompt from rag.prompts.generator import vision_llm_describe_prompt
from rag.settings import PARALLEL_DEVICES from rag.settings import PARALLEL_DEVICES
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber" LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"

View File

@ -350,7 +350,7 @@ class TextRecognizer:
def close(self): def close(self):
# close session and release manually # close session and release manually
logging.info('Close TextRecognizer.') logging.info('Close text recognizer.')
if hasattr(self, "predictor"): if hasattr(self, "predictor"):
del self.predictor del self.predictor
gc.collect() gc.collect()
@ -490,7 +490,7 @@ class TextDetector:
return dt_boxes return dt_boxes
def close(self): def close(self):
logging.info("Close TextDetector.") logging.info("Close text detector.")
if hasattr(self, "predictor"): if hasattr(self, "predictor"):
del self.predictor del self.predictor
gc.collect() gc.collect()

View File

@ -1,6 +1,9 @@
ragflow: ragflow:
host: ${RAGFLOW_HOST:-0.0.0.0} host: ${RAGFLOW_HOST:-0.0.0.0}
http_port: 9380 http_port: 9380
admin:
host: ${RAGFLOW_HOST:-0.0.0.0}
http_port: 9381
mysql: mysql:
name: '${MYSQL_DBNAME:-rag_flow}' name: '${MYSQL_DBNAME:-rag_flow}'
user: '${MYSQL_USER:-root}' user: '${MYSQL_USER:-root}'

View File

@ -3,6 +3,6 @@
"position": 40, "position": 40,
"link": { "link": {
"type": "generated-index", "type": "generated-index",
"description": "Guides and references on accessing RAGFlow's knowledge bases via MCP." "description": "Guides and references on accessing RAGFlow's datasets via MCP."
} }
} }

View File

@ -14,9 +14,9 @@ A RAGFlow Model Context Protocol (MCP) server is designed as an independent comp
An MCP server can start up in either self-host mode (default) or host mode: An MCP server can start up in either self-host mode (default) or host mode:
- **Self-host mode**: - **Self-host mode**:
When launching an MCP server in self-host mode, you must provide an API key to authenticate the MCP server with the RAGFlow server. In this mode, the MCP server can access *only* the datasets (knowledge bases) of a specified tenant on the RAGFlow server. When launching an MCP server in self-host mode, you must provide an API key to authenticate the MCP server with the RAGFlow server. In this mode, the MCP server can access *only* the datasets of a specified tenant on the RAGFlow server.
- **Host mode**: - **Host mode**:
In host mode, each MCP client can access their own knowledge bases on the RAGFlow server. However, each client request must include a valid API key to authenticate the client with the RAGFlow server. In host mode, each MCP client can access their own datasets on the RAGFlow server. However, each client request must include a valid API key to authenticate the client with the RAGFlow server.
Once a connection is established, an MCP server communicates with its client in MCP HTTP+SSE (Server-Sent Events) mode, unidirectionally pushing responses from the RAGFlow server to its client in real time. Once a connection is established, an MCP server communicates with its client in MCP HTTP+SSE (Server-Sent Events) mode, unidirectionally pushing responses from the RAGFlow server to its client in real time.

View File

@ -498,7 +498,7 @@ To switch your document engine from Elasticsearch to [Infinity](https://github.c
### Where are my uploaded files stored in RAGFlow's image? ### Where are my uploaded files stored in RAGFlow's image?
All uploaded files are stored in Minio, RAGFlow's object storage solution. For instance, if you upload your file directly to a knowledge base, it is located at `<knowledgebase_id>/filename`. All uploaded files are stored in Minio, RAGFlow's object storage solution. For instance, if you upload your file directly to a dataset, it is located at `<knowledgebase_id>/filename`.
--- ---
@ -507,3 +507,16 @@ All uploaded files are stored in Minio, RAGFlow's object storage solution. For i
You can control the batch size for document parsing and embedding by setting the environment variables `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`. Increasing these values may improve throughput for large-scale data processing, but will also increase memory usage. Adjust them according to your hardware resources. You can control the batch size for document parsing and embedding by setting the environment variables `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`. Increasing these values may improve throughput for large-scale data processing, but will also increase memory usage. Adjust them according to your hardware resources.
--- ---
### How to accelerate the question-answering speed of my chat assistant?
See [here](./guides/chat/best_practices/accelerate_question_answering.mdx).
---
### How to accelerate the question-answering speed of my Agent?
See [here](./guides/agent/best_practices/accelerate_agent_question_answering.md).
---

View File

@ -229,18 +229,4 @@ The global variable name for the output of the **Agent** component, which can be
### Why does it take so long for my Agent to respond? ### Why does it take so long for my Agent to respond?
An Agents response time generally depends on two key factors: the LLMs capabilities and the prompt, the latter reflecting task complexity. When using an Agent, you should always balance task demands with the LLMs ability. See [How to balance task complexity with an Agent's performance and speed?](#how-to-balance-task-complexity-with-an-agents-performance-and-speed) for details. See [here](../best_practices/accelerate_agent_question_answering.md) for details.
## Best practices
### How to balance task complexity with an Agents performance and speed?
- For simple tasks, such as retrieval, rewriting, formatting, or structured data extraction, use concise prompts, remove planning or reasoning instructions, enforce output length limits, and select smaller or Turbo-class models. This significantly reduces latency and cost with minimal impact on quality.
- For complex tasks, like multi-step reasoning, cross-document synthesis, or tool-based workflows, maintain or enhance prompts that include planning, reflection, and verification steps.
- In multi-Agent orchestration systems, delegate simple subtasks to sub-Agents using smaller, faster models, and reserve more powerful models for the lead Agent to handle complexity and uncertainty.
:::tip KEY INSIGHT
Focus on minimizing output tokens — through summarization, bullet points, or explicit length limits — as this has far greater impact on reducing latency than optimizing input size.
:::

View File

@ -67,14 +67,14 @@ You can tune document parsing and embedding efficiency by setting the environmen
## Frequently asked questions ## Frequently asked questions
### Is the uploaded file in a knowledge base? ### Is the uploaded file in a dataset?
No. Files uploaded to an agent as input are not stored in a knowledge base and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods. No. Files uploaded to an agent as input are not stored in a dataset and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods.
### File size limit for an uploaded file ### File size limit for an uploaded file
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete. There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
:::tip NOTE :::tip NOTE
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a knowledge base or **File Management**. These settings DO NOT apply in this scenario. The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario.
::: :::

View File

@ -49,6 +49,10 @@ You can specify multiple input sources for the **Code** component. Click **+ Add
This field allows you to enter and edit your source code. This field allows you to enter and edit your source code.
:::danger IMPORTANT
If your code implementation includes defined variables, whether input or output variables, ensure they are also specified in the corresponding **Input** or **Output** sections.
:::
#### A Python code example #### A Python code example
```Python ```Python
@ -77,6 +81,15 @@ This field allows you to enter and edit your source code.
You define the output variable(s) of the **Code** component here. You define the output variable(s) of the **Code** component here.
:::danger IMPORTANT
If you define output variables here, ensure they are also defined in your code implementation; otherwise, their values will be `null`. The following are two examples:
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/set_object_output.jpg)
![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/set_nested_object_output.png)
:::
### Output ### Output
The defined output variable(s) will be auto-populated here. The defined output variable(s) will be auto-populated here.

View File

@ -9,7 +9,7 @@ A component that retrieves information from specified datasets.
## Scenarios ## Scenarios
A **Retrieval** component is essential in most RAG scenarios, where information is extracted from designated knowledge bases before being sent to the LLM for content generation. A **Retrieval** component can operate either as a standalone workflow module or as a tool for an **Agent** component. In the latter role, the **Agent** component has autonomous control over when to invoke it for query and retrieval. A **Retrieval** component is essential in most RAG scenarios, where information is extracted from designated datasets before being sent to the LLM for content generation. A **Retrieval** component can operate either as a standalone workflow module or as a tool for an **Agent** component. In the latter role, the **Agent** component has autonomous control over when to invoke it for query and retrieval.
The following screenshot shows a reference design using the **Retrieval** component, where the component serves as a tool for an **Agent** component. You can find it from the **Report Agent Using Knowledge Base** Agent template. The following screenshot shows a reference design using the **Retrieval** component, where the component serves as a tool for an **Agent** component. You can find it from the **Report Agent Using Knowledge Base** Agent template.
@ -17,7 +17,7 @@ The following screenshot shows a reference design using the **Retrieval** compon
## Prerequisites ## Prerequisites
Ensure you [have properly configured your target knowledge base(s)](../../dataset/configure_knowledge_base.md). Ensure you [have properly configured your target dataset(s)](../../dataset/configure_knowledge_base.md).
## Quickstart ## Quickstart
@ -36,9 +36,9 @@ The **Retrieval** component depends on query variables to specify its queries.
By default, you can use `sys.query`, which is the user query and the default output of the **Begin** component. All global variables defined before the **Retrieval** component can also be used as query statements. Use the `(x)` button or type `/` to show all the available query variables. By default, you can use `sys.query`, which is the user query and the default output of the **Begin** component. All global variables defined before the **Retrieval** component can also be used as query statements. Use the `(x)` button or type `/` to show all the available query variables.
### 3. Select knowledge base(s) to query ### 3. Select dataset(s) to query
You can specify one or multiple knowledge bases to retrieve data from. If selecting mutiple, ensure they use the same embedding model. You can specify one or multiple datasets to retrieve data from. If selecting mutiple, ensure they use the same embedding model.
### 4. Expand **Advanced Settings** to configure the retrieval method ### 4. Expand **Advanced Settings** to configure the retrieval method
@ -52,7 +52,7 @@ Using a rerank model will *significantly* increase the system's response time. I
### 5. Enable cross-language search ### 5. Enable cross-language search
If your user query is different from the languages of the knowledge bases, you can select the target languages in the **Cross-language search** dropdown menu. The model will then translates queries to ensure accurate matching of semantic meaning across languages. If your user query is different from the languages of the datasets, you can select the target languages in the **Cross-language search** dropdown menu. The model will then translates queries to ensure accurate matching of semantic meaning across languages.
### 6. Test retrieval results ### 6. Test retrieval results
@ -76,10 +76,10 @@ The **Retrieval** component relies on query variables to specify its queries. Al
### Knowledge bases ### Knowledge bases
Select the knowledge base(s) to retrieve data from. Select the dataset(s) to retrieve data from.
- If no knowledge base is selected, meaning conversations with the agent will not be based on any knowledge base, ensure that the **Empty response** field is left blank to avoid an error. - If no dataset is selected, meaning conversations with the agent will not be based on any dataset, ensure that the **Empty response** field is left blank to avoid an error.
- If you select multiple knowledge bases, you must ensure that the knowledge bases (datasets) you select use the same embedding model; otherwise, an error message would occur. - If you select multiple datasets, you must ensure that the datasets you select use the same embedding model; otherwise, an error message would occur.
### Similarity threshold ### Similarity threshold
@ -110,11 +110,11 @@ Using a rerank model will *significantly* increase the system's response time.
### Empty response ### Empty response
- Set this as a response if no results are retrieved from the knowledge base(s) for your query, or - Set this as a response if no results are retrieved from the dataset(s) for your query, or
- Leave this field blank to allow the chat model to improvise when nothing is found. - Leave this field blank to allow the chat model to improvise when nothing is found.
:::caution WARNING :::caution WARNING
If you do not specify a knowledge base, you must leave this field blank; otherwise, an error would occur. If you do not specify a dataset, you must leave this field blank; otherwise, an error would occur.
::: :::
### Cross-language search ### Cross-language search
@ -124,10 +124,10 @@ Select one or more languages for crosslanguage search. If no language is sele
### Use knowledge graph ### Use knowledge graph
:::caution IMPORTANT :::caution IMPORTANT
Before enabling this feature, ensure you have properly [constructed a knowledge graph from each target knowledge base](../../dataset/construct_knowledge_graph.md). Before enabling this feature, ensure you have properly [constructed a knowledge graph from each target dataset](../../dataset/construct_knowledge_graph.md).
::: :::
Whether to use knowledge graph(s) in the specified knowledge base(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time. Whether to use knowledge graph(s) in the specified dataset(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.
### Output ### Output

View File

@ -27,7 +27,7 @@ Agents and RAG are complementary techniques, each enhancing the others capabi
Before proceeding, ensure that: Before proceeding, ensure that:
1. You have properly set the LLM to use. See the guides on [Configure your API key](../models/llm_api_key_setup.md) or [Deploy a local LLM](../models/deploy_local_llm.mdx) for more information. 1. You have properly set the LLM to use. See the guides on [Configure your API key](../models/llm_api_key_setup.md) or [Deploy a local LLM](../models/deploy_local_llm.mdx) for more information.
2. You have a knowledge base configured and the corresponding files properly parsed. See the guide on [Configure a knowledge base](../dataset/configure_knowledge_base.md) for more information. 2. You have a dataset configured and the corresponding files properly parsed. See the guide on [Configure a dataset](../dataset/configure_knowledge_base.md) for more information.
::: :::

View File

@ -0,0 +1,8 @@
{
"label": "Best practices",
"position": 30,
"link": {
"type": "generated-index",
"description": "Best practices on Agent configuration."
}
}

View File

@ -0,0 +1,58 @@
---
sidebar_position: 1
slug: /accelerate_agent_question_answering
---
# Accelerate answering
A checklist to speed up question answering.
---
Please note that some of your settings may consume a significant amount of time. If you often find that your question answering is time-consuming, here is a checklist to consider:
## Balance task complexity with an Agents performance and speed?
An Agents response time generally depends on many factors, e.g., the LLMs capabilities and the prompt, the latter reflecting task complexity. When using an Agent, you should always balance task demands with the LLMs ability.
- For simple tasks, such as retrieval, rewriting, formatting, or structured data extraction, use concise prompts, remove planning or reasoning instructions, enforce output length limits, and select smaller or Turbo-class models. This significantly reduces latency and cost with minimal impact on quality.
- For complex tasks, like multi-step reasoning, cross-document synthesis, or tool-based workflows, maintain or enhance prompts that include planning, reflection, and verification steps.
- In multi-Agent orchestration systems, delegate simple subtasks to sub-Agents using smaller, faster models, and reserve more powerful models for the lead Agent to handle complexity and uncertainty.
:::tip KEY INSIGHT
Focus on minimizing output tokens — through summarization, bullet points, or explicit length limits — as this has far greater impact on reducing latency than optimizing input size.
:::
## Disable Reasoning
Disabling the **Reasoning** toggle will reduce the LLM's thinking time. For a model like Qwen3, you also need to add `/no_think` to the system prompt to disable reasoning.
## Disable Rerank model
- Leaving the **Rerank model** field empty (in the corresponding **Retrieval** component) will significantly decrease retrieval time.
- When using a rerank model, ensure you have a GPU for acceleration; otherwise, the reranking process will be *prohibitively* slow.
:::tip NOTE
Please note that rerank models are essential in certain scenarios. There is always a trade-off between speed and performance; you must weigh the pros against cons for your specific case.
:::
## Check the time taken for each task
Click the light bulb icon above the *current* dialogue and scroll down the popup window to view the time taken for each task:
| Item name | Description |
| ----------------- | --------------------------------------------------------------------------------------------- |
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
| Check LLM | Time to validate the specified LLM. |
| Create retriever | Time to create a chunk retriever. |
| Bind embedding | Time to initialize an embedding model instance. |
| Bind LLM | Time to initialize an LLM instance. |
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
| Generate keywords | Time to extract keywords from the user query. |
| Retrieval | Time to retrieve the chunks. |
| Generate answer | Time to generate the answer. |

View File

@ -22,7 +22,7 @@ When debugging your chat assistant, you can use AI search as a reference to veri
## Prerequisites ## Prerequisites
- Ensure that you have configured the system's default models on the **Model providers** page. - Ensure that you have configured the system's default models on the **Model providers** page.
- Ensure that the intended knowledge bases are properly configured and the intended documents have finished file parsing. - Ensure that the intended datasets are properly configured and the intended documents have finished file parsing.
## Frequently asked questions ## Frequently asked questions

Some files were not shown because too many files have changed in this diff Show More