mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix: Merge main branch (#10377)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: jinhai <haijin.chn@gmail.com> Signed-off-by: Jin Hai <haijin.chn@gmail.com> Co-authored-by: Lynn <lynn_inf@hotmail.com> Co-authored-by: chanx <1243304602@qq.com> Co-authored-by: balibabu <cike8899@users.noreply.github.com> Co-authored-by: 纷繁下的无奈 <zhileihuang@126.com> Co-authored-by: huangzl <huangzl@shinemo.com> Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com> Co-authored-by: Wilmer <33392318@qq.com> Co-authored-by: Adrian Weidig <adrianweidig@gmx.net> Co-authored-by: Zhichang Yu <yuzhichang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Yongteng Lei <yongtengrey@outlook.com> Co-authored-by: Liu An <asiro@qq.com> Co-authored-by: buua436 <66937541+buua436@users.noreply.github.com> Co-authored-by: BadwomanCraZY <511528396@qq.com> Co-authored-by: cucusenok <31804608+cucusenok@users.noreply.github.com> Co-authored-by: Russell Valentine <russ@coldstonelabs.org> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Billy Bao <newyorkupperbay@gmail.com> Co-authored-by: Zhedong Cen <cenzhedong2@126.com> Co-authored-by: TensorNull <129579691+TensorNull@users.noreply.github.com> Co-authored-by: TensorNull <tensor.null@gmail.com> Co-authored-by: Ajay <160579663+aybanda@users.noreply.github.com> Co-authored-by: AB <aj@Ajays-MacBook-Air.local> Co-authored-by: 天海蒼灆 <huangaoqin@tecpie.com> Co-authored-by: He Wang <wanghechn@qq.com> Co-authored-by: Atsushi Hatakeyama <atu729@icloud.com> Co-authored-by: Jin Hai <haijin.chn@gmail.com> Co-authored-by: Mohamed Mathari <155896313+melmathari@users.noreply.github.com> Co-authored-by: Mohamed Mathari <nocodeventure@Mac-mini-van-Mohamed.fritz.box> Co-authored-by: Stephen Hu <stephenhu@seismic.com> Co-authored-by: Shaun Zhang <zhangwfjh@users.noreply.github.com> Co-authored-by: zhimeng123 <60221886+zhimeng123@users.noreply.github.com> Co-authored-by: mxc <mxc@example.com> Co-authored-by: Dominik Novotný <50611433+SgtMarmite@users.noreply.github.com> Co-authored-by: EVGENY M <168018528+rjohny55@users.noreply.github.com> Co-authored-by: mcoder6425 <mcoder64@gmail.com> Co-authored-by: TeslaZY <TeslaZY@outlook.com> Co-authored-by: lemsn <lemsn@msn.com> Co-authored-by: lemsn <lemsn@126.com> Co-authored-by: Adrian Gora <47756404+adagora@users.noreply.github.com> Co-authored-by: Womsxd <45663319+Womsxd@users.noreply.github.com> Co-authored-by: FatMii <39074672+FatMii@users.noreply.github.com>
This commit is contained in:
101
admin/README.md
Normal file
101
admin/README.md
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# RAGFlow Admin Service & CLI
|
||||||
|
|
||||||
|
### Introduction
|
||||||
|
|
||||||
|
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
|
||||||
|
|
||||||
|
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
|
||||||
|
|
||||||
|
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
|
||||||
|
|
||||||
|
Built with scalability and reliability in mind, the Admin Service ensures smooth system operation and simplifies maintenance workflows.
|
||||||
|
|
||||||
|
It consists of a server-side Service and a command-line client (CLI), both implemented in Python. User commands are parsed using the Lark parsing toolkit.
|
||||||
|
|
||||||
|
- **Admin Service**: A backend service that interfaces with the RAGFlow system to execute administrative operations and monitor its status.
|
||||||
|
- **Admin CLI**: A command-line interface that allows users to connect to the Admin Service and issue commands for system management.
|
||||||
|
|
||||||
|
### Starting the Admin Service
|
||||||
|
|
||||||
|
1. Before start Admin Service, please make sure RAGFlow system is already started.
|
||||||
|
|
||||||
|
2. Run the service script:
|
||||||
|
```bash
|
||||||
|
python admin/admin_server.py
|
||||||
|
```
|
||||||
|
The service will start and listen for incoming connections from the CLI on the configured port.
|
||||||
|
|
||||||
|
### Using the Admin CLI
|
||||||
|
|
||||||
|
1. Ensure the Admin Service is running.
|
||||||
|
2. Launch the CLI client:
|
||||||
|
```bash
|
||||||
|
python admin/admin_client.py -h 0.0.0.0 -p 9381
|
||||||
|
|
||||||
|
## Supported Commands
|
||||||
|
|
||||||
|
Commands are case-insensitive and must be terminated with a semicolon (`;`).
|
||||||
|
|
||||||
|
### Service Management Commands
|
||||||
|
|
||||||
|
- `LIST SERVICES;`
|
||||||
|
- Lists all available services within the RAGFlow system.
|
||||||
|
- `SHOW SERVICE <id>;`
|
||||||
|
- Shows detailed status information for the service identified by `<id>`.
|
||||||
|
- `STARTUP SERVICE <id>;`
|
||||||
|
- Attempts to start the service identified by `<id>`.
|
||||||
|
- `SHUTDOWN SERVICE <id>;`
|
||||||
|
- Attempts to gracefully shut down the service identified by `<id>`.
|
||||||
|
- `RESTART SERVICE <id>;`
|
||||||
|
- Attempts to restart the service identified by `<id>`.
|
||||||
|
|
||||||
|
### User Management Commands
|
||||||
|
|
||||||
|
- `LIST USERS;`
|
||||||
|
- Lists all users known to the system.
|
||||||
|
- `SHOW USER '<username>';`
|
||||||
|
- Shows details and permissions for the specified user. The username must be enclosed in single or double quotes.
|
||||||
|
- `DROP USER '<username>';`
|
||||||
|
- Removes the specified user from the system. Use with caution.
|
||||||
|
- `ALTER USER PASSWORD '<username>' '<new_password>';`
|
||||||
|
- Changes the password for the specified user.
|
||||||
|
|
||||||
|
### Data and Agent Commands
|
||||||
|
|
||||||
|
- `LIST DATASETS OF '<username>';`
|
||||||
|
- Lists the datasets associated with the specified user.
|
||||||
|
- `LIST AGENTS OF '<username>';`
|
||||||
|
- Lists the agents associated with the specified user.
|
||||||
|
|
||||||
|
### Meta-Commands
|
||||||
|
|
||||||
|
Meta-commands are prefixed with a backslash (`\`).
|
||||||
|
|
||||||
|
- `\?` or `\help`
|
||||||
|
- Shows help information for the available commands.
|
||||||
|
- `\q` or `\quit`
|
||||||
|
- Exits the CLI application.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```commandline
|
||||||
|
admin> list users;
|
||||||
|
+-------------------------------+------------------------+-----------+-------------+
|
||||||
|
| create_date | email | is_active | nickname |
|
||||||
|
+-------------------------------+------------------------+-----------+-------------+
|
||||||
|
| Fri, 22 Nov 2024 16:03:41 GMT | jeffery@infiniflow.org | 1 | Jeffery |
|
||||||
|
| Fri, 22 Nov 2024 16:10:55 GMT | aya@infiniflow.org | 1 | Waterdancer |
|
||||||
|
+-------------------------------+------------------------+-----------+-------------+
|
||||||
|
|
||||||
|
admin> list services;
|
||||||
|
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
|
||||||
|
| extra | host | id | name | port | service_type |
|
||||||
|
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
|
||||||
|
| {} | 0.0.0.0 | 0 | ragflow_0 | 9380 | ragflow_server |
|
||||||
|
| {'meta_type': 'mysql', 'password': 'infini_rag_flow', 'username': 'root'} | localhost | 1 | mysql | 5455 | meta_data |
|
||||||
|
| {'password': 'infini_rag_flow', 'store_type': 'minio', 'user': 'rag_flow'} | localhost | 2 | minio | 9000 | file_store |
|
||||||
|
| {'password': 'infini_rag_flow', 'retrieval_type': 'elasticsearch', 'username': 'elastic'} | localhost | 3 | elasticsearch | 1200 | retrieval |
|
||||||
|
| {'db_name': 'default_db', 'retrieval_type': 'infinity'} | localhost | 4 | infinity | 23817 | retrieval |
|
||||||
|
| {'database': 1, 'mq_type': 'redis', 'password': 'infini_rag_flow'} | localhost | 5 | redis | 6379 | message_queue |
|
||||||
|
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
|
||||||
|
```
|
||||||
574
admin/admin_client.py
Normal file
574
admin/admin_client.py
Normal file
@ -0,0 +1,574 @@
|
|||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from Cryptodome.PublicKey import RSA
|
||||||
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||||
|
from typing import Dict, List, Any
|
||||||
|
from lark import Lark, Transformer, Tree
|
||||||
|
import requests
|
||||||
|
from requests.auth import HTTPBasicAuth
|
||||||
|
from api.common.base64 import encode_to_base64
|
||||||
|
|
||||||
|
GRAMMAR = r"""
|
||||||
|
start: command
|
||||||
|
|
||||||
|
command: sql_command | meta_command
|
||||||
|
|
||||||
|
sql_command: list_services
|
||||||
|
| show_service
|
||||||
|
| startup_service
|
||||||
|
| shutdown_service
|
||||||
|
| restart_service
|
||||||
|
| list_users
|
||||||
|
| show_user
|
||||||
|
| drop_user
|
||||||
|
| alter_user
|
||||||
|
| create_user
|
||||||
|
| activate_user
|
||||||
|
| list_datasets
|
||||||
|
| list_agents
|
||||||
|
|
||||||
|
// meta command definition
|
||||||
|
meta_command: "\\" meta_command_name [meta_args]
|
||||||
|
|
||||||
|
meta_command_name: /[a-zA-Z?]+/
|
||||||
|
meta_args: (meta_arg)+
|
||||||
|
|
||||||
|
meta_arg: /[^\\s"']+/ | quoted_string
|
||||||
|
|
||||||
|
// command definition
|
||||||
|
|
||||||
|
LIST: "LIST"i
|
||||||
|
SERVICES: "SERVICES"i
|
||||||
|
SHOW: "SHOW"i
|
||||||
|
CREATE: "CREATE"i
|
||||||
|
SERVICE: "SERVICE"i
|
||||||
|
SHUTDOWN: "SHUTDOWN"i
|
||||||
|
STARTUP: "STARTUP"i
|
||||||
|
RESTART: "RESTART"i
|
||||||
|
USERS: "USERS"i
|
||||||
|
DROP: "DROP"i
|
||||||
|
USER: "USER"i
|
||||||
|
ALTER: "ALTER"i
|
||||||
|
ACTIVE: "ACTIVE"i
|
||||||
|
PASSWORD: "PASSWORD"i
|
||||||
|
DATASETS: "DATASETS"i
|
||||||
|
OF: "OF"i
|
||||||
|
AGENTS: "AGENTS"i
|
||||||
|
|
||||||
|
list_services: LIST SERVICES ";"
|
||||||
|
show_service: SHOW SERVICE NUMBER ";"
|
||||||
|
startup_service: STARTUP SERVICE NUMBER ";"
|
||||||
|
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
|
||||||
|
restart_service: RESTART SERVICE NUMBER ";"
|
||||||
|
|
||||||
|
list_users: LIST USERS ";"
|
||||||
|
drop_user: DROP USER quoted_string ";"
|
||||||
|
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||||
|
show_user: SHOW USER quoted_string ";"
|
||||||
|
create_user: CREATE USER quoted_string quoted_string ";"
|
||||||
|
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||||
|
|
||||||
|
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||||
|
list_agents: LIST AGENTS OF quoted_string ";"
|
||||||
|
|
||||||
|
identifier: WORD
|
||||||
|
quoted_string: QUOTED_STRING
|
||||||
|
status: WORD
|
||||||
|
|
||||||
|
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||||
|
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||||
|
NUMBER: /[0-9]+/
|
||||||
|
|
||||||
|
%import common.WS
|
||||||
|
%ignore WS
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AdminTransformer(Transformer):
|
||||||
|
|
||||||
|
def start(self, items):
|
||||||
|
return items[0]
|
||||||
|
|
||||||
|
def command(self, items):
|
||||||
|
return items[0]
|
||||||
|
|
||||||
|
def list_services(self, items):
|
||||||
|
result = {'type': 'list_services'}
|
||||||
|
return result
|
||||||
|
|
||||||
|
def show_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "show_service", "number": service_id}
|
||||||
|
|
||||||
|
def startup_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "startup_service", "number": service_id}
|
||||||
|
|
||||||
|
def shutdown_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "shutdown_service", "number": service_id}
|
||||||
|
|
||||||
|
def restart_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "restart_service", "number": service_id}
|
||||||
|
|
||||||
|
def list_users(self, items):
|
||||||
|
return {"type": "list_users"}
|
||||||
|
|
||||||
|
def show_user(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
return {"type": "show_user", "username": user_name}
|
||||||
|
|
||||||
|
def drop_user(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
return {"type": "drop_user", "username": user_name}
|
||||||
|
|
||||||
|
def alter_user(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
new_password = items[4]
|
||||||
|
return {"type": "alter_user", "username": user_name, "password": new_password}
|
||||||
|
|
||||||
|
def create_user(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
password = items[3]
|
||||||
|
return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
|
||||||
|
|
||||||
|
def activate_user(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
activate_status = items[4]
|
||||||
|
return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
|
||||||
|
|
||||||
|
def list_datasets(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
return {"type": "list_datasets", "username": user_name}
|
||||||
|
|
||||||
|
def list_agents(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
return {"type": "list_agents", "username": user_name}
|
||||||
|
|
||||||
|
def meta_command(self, items):
|
||||||
|
command_name = str(items[0]).lower()
|
||||||
|
args = items[1:] if len(items) > 1 else []
|
||||||
|
|
||||||
|
# handle quoted parameter
|
||||||
|
parsed_args = []
|
||||||
|
for arg in args:
|
||||||
|
if hasattr(arg, 'value'):
|
||||||
|
parsed_args.append(arg.value)
|
||||||
|
else:
|
||||||
|
parsed_args.append(str(arg))
|
||||||
|
|
||||||
|
return {'type': 'meta', 'command': command_name, 'args': parsed_args}
|
||||||
|
|
||||||
|
def meta_command_name(self, items):
|
||||||
|
return items[0]
|
||||||
|
|
||||||
|
def meta_args(self, items):
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt(input_string):
|
||||||
|
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
|
||||||
|
pub_key = RSA.importKey(pub)
|
||||||
|
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||||
|
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
|
||||||
|
return base64.b64encode(cipher_text).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class AdminCommandParser:
|
||||||
|
def __init__(self):
|
||||||
|
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
|
||||||
|
self.command_history = []
|
||||||
|
|
||||||
|
def parse_command(self, command_str: str) -> Dict[str, Any]:
|
||||||
|
if not command_str.strip():
|
||||||
|
return {'type': 'empty'}
|
||||||
|
|
||||||
|
self.command_history.append(command_str)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.parser.parse(command_str)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return {'type': 'error', 'message': f'Parse error: {str(e)}'}
|
||||||
|
|
||||||
|
|
||||||
|
class AdminCLI:
|
||||||
|
def __init__(self):
|
||||||
|
self.parser = AdminCommandParser()
|
||||||
|
self.is_interactive = False
|
||||||
|
self.admin_account = "admin@ragflow.io"
|
||||||
|
self.admin_password: str = "admin"
|
||||||
|
self.host: str = ""
|
||||||
|
self.port: int = 0
|
||||||
|
|
||||||
|
def verify_admin(self, args):
|
||||||
|
|
||||||
|
conn_info = self._parse_connection_args(args)
|
||||||
|
if 'error' in conn_info:
|
||||||
|
print(f"Error: {conn_info['error']}")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.host = conn_info['host']
|
||||||
|
self.port = conn_info['port']
|
||||||
|
print(f"Attempt to access ip: {self.host}, port: {self.port}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/auth'
|
||||||
|
|
||||||
|
try_count = 0
|
||||||
|
while True:
|
||||||
|
try_count += 1
|
||||||
|
if try_count > 3:
|
||||||
|
return False
|
||||||
|
|
||||||
|
admin_passwd = input(f"password for {self.admin_account}: ").strip()
|
||||||
|
try:
|
||||||
|
self.admin_password = encode_to_base64(admin_passwd)
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
if response.status_code == 200:
|
||||||
|
res_json = response.json()
|
||||||
|
error_code = res_json.get('code', -1)
|
||||||
|
if error_code == 0:
|
||||||
|
print("Authentication successful.")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
error_message = res_json.get('message', 'Unknown error')
|
||||||
|
print(f"Authentication failed: {error_message}, try again")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
print(f"Bad response,status: {response.status_code}, try again")
|
||||||
|
except Exception:
|
||||||
|
print(f"Can't access {self.host}, port: {self.port}")
|
||||||
|
|
||||||
|
def _print_table_simple(self, data):
|
||||||
|
if not data:
|
||||||
|
print("No data to print")
|
||||||
|
return
|
||||||
|
if isinstance(data, dict):
|
||||||
|
# handle single row data
|
||||||
|
data = [data]
|
||||||
|
|
||||||
|
columns = list(data[0].keys())
|
||||||
|
col_widths = {}
|
||||||
|
|
||||||
|
for col in columns:
|
||||||
|
max_width = len(str(col))
|
||||||
|
for item in data:
|
||||||
|
value_len = len(str(item.get(col, '')))
|
||||||
|
if value_len > max_width:
|
||||||
|
max_width = value_len
|
||||||
|
col_widths[col] = max(2, max_width)
|
||||||
|
|
||||||
|
# Generate delimiter
|
||||||
|
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
|
||||||
|
|
||||||
|
# Print header
|
||||||
|
print(separator)
|
||||||
|
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
|
||||||
|
print(header)
|
||||||
|
print(separator)
|
||||||
|
|
||||||
|
# Print data
|
||||||
|
for item in data:
|
||||||
|
row = "|"
|
||||||
|
for col in columns:
|
||||||
|
value = str(item.get(col, ''))
|
||||||
|
if len(value) > col_widths[col]:
|
||||||
|
value = value[:col_widths[col] - 3] + "..."
|
||||||
|
row += f" {value:<{col_widths[col]}} |"
|
||||||
|
print(row)
|
||||||
|
|
||||||
|
print(separator)
|
||||||
|
|
||||||
|
def run_interactive(self):
|
||||||
|
|
||||||
|
self.is_interactive = True
|
||||||
|
print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
command = input("admin> ").strip()
|
||||||
|
if not command:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"command: {command}")
|
||||||
|
result = self.parser.parse_command(command)
|
||||||
|
self.execute_command(result)
|
||||||
|
|
||||||
|
if isinstance(result, Tree):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
|
||||||
|
break
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nUse '\\q' to quit")
|
||||||
|
except EOFError:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
def run_single_command(self, args):
|
||||||
|
conn_info = self._parse_connection_args(args)
|
||||||
|
if 'error' in conn_info:
|
||||||
|
print(f"Error: {conn_info['error']}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
|
||||||
|
parser = argparse.ArgumentParser(description='Admin CLI Client', add_help=False)
|
||||||
|
parser.add_argument('-h', '--host', default='localhost', help='Admin service host')
|
||||||
|
parser.add_argument('-p', '--port', type=int, default=8080, help='Admin service port')
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args, remaining_args = parser.parse_known_args(args)
|
||||||
|
return {
|
||||||
|
'host': parsed_args.host,
|
||||||
|
'port': parsed_args.port,
|
||||||
|
}
|
||||||
|
except SystemExit:
|
||||||
|
return {'error': 'Invalid connection arguments'}
|
||||||
|
|
||||||
|
def execute_command(self, parsed_command: Dict[str, Any]):
|
||||||
|
|
||||||
|
command_dict: dict
|
||||||
|
if isinstance(parsed_command, Tree):
|
||||||
|
command_dict = parsed_command.children[0]
|
||||||
|
else:
|
||||||
|
if parsed_command['type'] == 'error':
|
||||||
|
print(f"Error: {parsed_command['message']}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
command_dict = parsed_command
|
||||||
|
|
||||||
|
# print(f"Parsed command: {command_dict}")
|
||||||
|
|
||||||
|
command_type = command_dict['type']
|
||||||
|
|
||||||
|
match command_type:
|
||||||
|
case 'list_services':
|
||||||
|
self._handle_list_services(command_dict)
|
||||||
|
case 'show_service':
|
||||||
|
self._handle_show_service(command_dict)
|
||||||
|
case 'restart_service':
|
||||||
|
self._handle_restart_service(command_dict)
|
||||||
|
case 'shutdown_service':
|
||||||
|
self._handle_shutdown_service(command_dict)
|
||||||
|
case 'startup_service':
|
||||||
|
self._handle_startup_service(command_dict)
|
||||||
|
case 'list_users':
|
||||||
|
self._handle_list_users(command_dict)
|
||||||
|
case 'show_user':
|
||||||
|
self._handle_show_user(command_dict)
|
||||||
|
case 'drop_user':
|
||||||
|
self._handle_drop_user(command_dict)
|
||||||
|
case 'alter_user':
|
||||||
|
self._handle_alter_user(command_dict)
|
||||||
|
case 'create_user':
|
||||||
|
self._handle_create_user(command_dict)
|
||||||
|
case 'activate_user':
|
||||||
|
self._handle_activate_user(command_dict)
|
||||||
|
case 'list_datasets':
|
||||||
|
self._handle_list_datasets(command_dict)
|
||||||
|
case 'list_agents':
|
||||||
|
self._handle_list_agents(command_dict)
|
||||||
|
case 'meta':
|
||||||
|
self._handle_meta_command(command_dict)
|
||||||
|
case _:
|
||||||
|
print(f"Command '{command_type}' would be executed with API")
|
||||||
|
|
||||||
|
def _handle_list_services(self, command):
|
||||||
|
print("Listing all services")
|
||||||
|
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_show_service(self, command):
|
||||||
|
service_id: int = command['number']
|
||||||
|
print(f"Showing service: {service_id}")
|
||||||
|
|
||||||
|
def _handle_restart_service(self, command):
|
||||||
|
service_id: int = command['number']
|
||||||
|
print(f"Restart service {service_id}")
|
||||||
|
|
||||||
|
def _handle_shutdown_service(self, command):
|
||||||
|
service_id: int = command['number']
|
||||||
|
print(f"Shutdown service {service_id}")
|
||||||
|
|
||||||
|
def _handle_startup_service(self, command):
|
||||||
|
service_id: int = command['number']
|
||||||
|
print(f"Startup service {service_id}")
|
||||||
|
|
||||||
|
def _handle_list_users(self, command):
|
||||||
|
print("Listing all users")
|
||||||
|
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_show_user(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
print(f"Showing user: {username}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_drop_user(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
print(f"Drop user: {username}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||||
|
response = requests.delete(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(res_json["message"])
|
||||||
|
else:
|
||||||
|
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_alter_user(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
password_tree: Tree = command['password']
|
||||||
|
password: str = password_tree.children[0].strip("'\"")
|
||||||
|
print(f"Alter user: {username}, password: {password}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
|
||||||
|
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||||
|
json={'new_password': encrypt(password)})
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(res_json["message"])
|
||||||
|
else:
|
||||||
|
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_create_user(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
password_tree: Tree = command['password']
|
||||||
|
password: str = password_tree.children[0].strip("'\"")
|
||||||
|
role: str = command['role']
|
||||||
|
print(f"Create user: {username}, password: {password}, role: {role}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||||
|
json={'username': username, 'password': encrypt(password), 'role': role}
|
||||||
|
)
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_activate_user(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
activate_tree: Tree = command['activate_status']
|
||||||
|
activate_status: str = activate_tree.children[0].strip("'\"")
|
||||||
|
if activate_status.lower() in ['on', 'off']:
|
||||||
|
print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
|
||||||
|
response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
|
||||||
|
json={'activate_status': activate_status})
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(res_json["message"])
|
||||||
|
else:
|
||||||
|
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
else:
|
||||||
|
print(f"Unknown activate status: {activate_status}.")
|
||||||
|
|
||||||
|
def _handle_list_datasets(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
print(f"Listing all datasets of user: {username}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_list_agents(self, command):
|
||||||
|
username_tree: Tree = command['username']
|
||||||
|
username: str = username_tree.children[0].strip("'\"")
|
||||||
|
print(f"Listing all agents of user: {username}")
|
||||||
|
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
|
||||||
|
response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||||
|
res_json = response.json()
|
||||||
|
if response.status_code == 200:
|
||||||
|
self._print_table_simple(res_json['data'])
|
||||||
|
else:
|
||||||
|
print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")
|
||||||
|
|
||||||
|
def _handle_meta_command(self, command):
|
||||||
|
meta_command = command['command']
|
||||||
|
args = command.get('args', [])
|
||||||
|
|
||||||
|
if meta_command in ['?', 'h', 'help']:
|
||||||
|
self.show_help()
|
||||||
|
elif meta_command in ['q', 'quit', 'exit']:
|
||||||
|
print("Goodbye!")
|
||||||
|
else:
|
||||||
|
print(f"Meta command '{meta_command}' with args {args}")
|
||||||
|
|
||||||
|
def show_help(self):
|
||||||
|
"""Help info"""
|
||||||
|
help_text = """
|
||||||
|
Commands:
|
||||||
|
LIST SERVICES
|
||||||
|
SHOW SERVICE <service>
|
||||||
|
STARTUP SERVICE <service>
|
||||||
|
SHUTDOWN SERVICE <service>
|
||||||
|
RESTART SERVICE <service>
|
||||||
|
LIST USERS
|
||||||
|
SHOW USER <user>
|
||||||
|
DROP USER <user>
|
||||||
|
CREATE USER <user> <password>
|
||||||
|
ALTER USER PASSWORD <user> <new_password>
|
||||||
|
ALTER USER ACTIVE <user> <on/off>
|
||||||
|
LIST DATASETS OF <user>
|
||||||
|
LIST AGENTS OF <user>
|
||||||
|
|
||||||
|
Meta Commands:
|
||||||
|
\\?, \\h, \\help Show this help
|
||||||
|
\\q, \\quit, \\exit Quit the CLI
|
||||||
|
"""
|
||||||
|
print(help_text)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import sys
|
||||||
|
|
||||||
|
cli = AdminCLI()
|
||||||
|
|
||||||
|
if len(sys.argv) == 1 or (len(sys.argv) > 1 and sys.argv[1] == '-'):
|
||||||
|
print(r"""
|
||||||
|
____ ___ ______________ ___ __ _
|
||||||
|
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
|
||||||
|
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
|
||||||
|
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
|
||||||
|
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
|
||||||
|
""")
|
||||||
|
if cli.verify_admin(sys.argv):
|
||||||
|
cli.run_interactive()
|
||||||
|
else:
|
||||||
|
if cli.verify_admin(sys.argv):
|
||||||
|
cli.run_interactive()
|
||||||
|
# cli.run_single_command(sys.argv[1:])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
47
admin/admin_server.py
Normal file
47
admin/admin_server.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import traceback
|
||||||
|
from werkzeug.serving import run_simple
|
||||||
|
from flask import Flask
|
||||||
|
from routes import admin_bp
|
||||||
|
from api.utils.log_utils import init_root_logger
|
||||||
|
from api.constants import SERVICE_CONF
|
||||||
|
from api import settings
|
||||||
|
from config import load_configurations, SERVICE_CONFIGS
|
||||||
|
|
||||||
|
stop_event = threading.Event()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
init_root_logger("admin_service")
|
||||||
|
logging.info(r"""
|
||||||
|
____ ___ ______________ ___ __ _
|
||||||
|
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
|
||||||
|
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
|
||||||
|
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
|
||||||
|
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
|
||||||
|
""")
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.register_blueprint(admin_bp)
|
||||||
|
settings.init_settings()
|
||||||
|
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.info("RAGFlow Admin service start...")
|
||||||
|
run_simple(
|
||||||
|
hostname="0.0.0.0",
|
||||||
|
port=9381,
|
||||||
|
application=app,
|
||||||
|
threaded=True,
|
||||||
|
use_reloader=True,
|
||||||
|
use_debugger=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
stop_event.set()
|
||||||
|
time.sleep(1)
|
||||||
|
os.kill(os.getpid(), signal.SIGKILL)
|
||||||
57
admin/auth.py
Normal file
57
admin/auth.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from functools import wraps
|
||||||
|
from flask import request, jsonify
|
||||||
|
|
||||||
|
from exceptions import AdminException
|
||||||
|
from api.db.init_data import encode_to_base64
|
||||||
|
from api.db.services import UserService
|
||||||
|
|
||||||
|
|
||||||
|
def check_admin(username: str, password: str):
|
||||||
|
users = UserService.query(email=username)
|
||||||
|
if not users:
|
||||||
|
logging.info(f"Username: {username} is not registered!")
|
||||||
|
user_info = {
|
||||||
|
"id": uuid.uuid1().hex,
|
||||||
|
"password": encode_to_base64("admin"),
|
||||||
|
"nickname": "admin",
|
||||||
|
"is_superuser": True,
|
||||||
|
"email": "admin@ragflow.io",
|
||||||
|
"creator": "system",
|
||||||
|
"status": "1",
|
||||||
|
}
|
||||||
|
if not UserService.save(**user_info):
|
||||||
|
raise AdminException("Can't init admin.", 500)
|
||||||
|
|
||||||
|
user = UserService.query_user(username, password)
|
||||||
|
if user:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def login_verify(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
auth = request.authorization
|
||||||
|
if not auth or 'username' not in auth.parameters or 'password' not in auth.parameters:
|
||||||
|
return jsonify({
|
||||||
|
"code": 401,
|
||||||
|
"message": "Authentication required",
|
||||||
|
"data": None
|
||||||
|
}), 200
|
||||||
|
|
||||||
|
username = auth.parameters['username']
|
||||||
|
password = auth.parameters['password']
|
||||||
|
# TODO: to check the username and password from DB
|
||||||
|
if check_admin(username, password) is False:
|
||||||
|
return jsonify({
|
||||||
|
"code": 403,
|
||||||
|
"message": "Access denied",
|
||||||
|
"data": None
|
||||||
|
}), 200
|
||||||
|
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
280
admin/config.py
Normal file
280
admin/config.py
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Any
|
||||||
|
from api.utils.configs import read_config
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceConfigs:
|
||||||
|
def __init__(self):
|
||||||
|
self.configs = []
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
SERVICE_CONFIGS = ServiceConfigs
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceType(Enum):
|
||||||
|
METADATA = "metadata"
|
||||||
|
RETRIEVAL = "retrieval"
|
||||||
|
MESSAGE_QUEUE = "message_queue"
|
||||||
|
RAGFLOW_SERVER = "ragflow_server"
|
||||||
|
TASK_EXECUTOR = "task_executor"
|
||||||
|
FILE_STORE = "file_store"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConfig(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
service_type: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, 'service_type': self.service_type}
|
||||||
|
|
||||||
|
|
||||||
|
class MetaConfig(BaseConfig):
|
||||||
|
meta_type: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['meta_type'] = self.meta_type
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLConfig(MetaConfig):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['username'] = self.username
|
||||||
|
extra_dict['password'] = self.password
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresConfig(MetaConfig):
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalConfig(BaseConfig):
|
||||||
|
retrieval_type: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['retrieval_type'] = self.retrieval_type
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class InfinityConfig(RetrievalConfig):
|
||||||
|
db_name: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['db_name'] = self.db_name
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticsearchConfig(RetrievalConfig):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['username'] = self.username
|
||||||
|
extra_dict['password'] = self.password
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MessageQueueConfig(BaseConfig):
|
||||||
|
mq_type: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['mq_type'] = self.mq_type
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RedisConfig(MessageQueueConfig):
|
||||||
|
database: int
|
||||||
|
password: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['database'] = self.database
|
||||||
|
extra_dict['password'] = self.password
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RabbitMQConfig(MessageQueueConfig):
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RAGFlowServerConfig(BaseConfig):
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutorConfig(BaseConfig):
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class FileStoreConfig(BaseConfig):
|
||||||
|
store_type: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['store_type'] = self.store_type
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MinioConfig(FileStoreConfig):
|
||||||
|
user: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if 'extra' not in result:
|
||||||
|
result['extra'] = dict()
|
||||||
|
extra_dict = result['extra'].copy()
|
||||||
|
extra_dict['user'] = self.user
|
||||||
|
extra_dict['password'] = self.password
|
||||||
|
result['extra'] = extra_dict
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_configurations(config_path: str) -> list[BaseConfig]:
|
||||||
|
raw_configs = read_config(config_path)
|
||||||
|
configurations = []
|
||||||
|
ragflow_count = 0
|
||||||
|
id_count = 0
|
||||||
|
for k, v in raw_configs.items():
|
||||||
|
match (k):
|
||||||
|
case "ragflow":
|
||||||
|
name: str = f'ragflow_{ragflow_count}'
|
||||||
|
host: str = v['host']
|
||||||
|
http_port: int = v['http_port']
|
||||||
|
config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port, service_type="ragflow_server")
|
||||||
|
configurations.append(config)
|
||||||
|
id_count += 1
|
||||||
|
case "es":
|
||||||
|
name: str = 'elasticsearch'
|
||||||
|
url = v['hosts']
|
||||||
|
parsed = urlparse(url)
|
||||||
|
host: str = parsed.hostname
|
||||||
|
port: int = parsed.port
|
||||||
|
username: str = v.get('username')
|
||||||
|
password: str = v.get('password')
|
||||||
|
config = ElasticsearchConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval",
|
||||||
|
retrieval_type="elasticsearch",
|
||||||
|
username=username, password=password)
|
||||||
|
configurations.append(config)
|
||||||
|
id_count += 1
|
||||||
|
|
||||||
|
case "infinity":
|
||||||
|
name: str = 'infinity'
|
||||||
|
url = v['uri']
|
||||||
|
parts = url.split(':', 1)
|
||||||
|
host = parts[0]
|
||||||
|
port = int(parts[1])
|
||||||
|
database: str = v.get('db_name', 'default_db')
|
||||||
|
config = InfinityConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", retrieval_type="infinity",
|
||||||
|
db_name=database)
|
||||||
|
configurations.append(config)
|
||||||
|
id_count += 1
|
||||||
|
case "minio":
|
||||||
|
name: str = 'minio'
|
||||||
|
url = v['host']
|
||||||
|
parts = url.split(':', 1)
|
||||||
|
host = parts[0]
|
||||||
|
port = int(parts[1])
|
||||||
|
user = v.get('user')
|
||||||
|
password = v.get('password')
|
||||||
|
config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, service_type="file_store",
|
||||||
|
store_type="minio")
|
||||||
|
configurations.append(config)
|
||||||
|
id_count += 1
|
||||||
|
case "redis":
|
||||||
|
name: str = 'redis'
|
||||||
|
url = v['host']
|
||||||
|
parts = url.split(':', 1)
|
||||||
|
host = parts[0]
|
||||||
|
port = int(parts[1])
|
||||||
|
password = v.get('password')
|
||||||
|
db: int = v.get('db')
|
||||||
|
config = RedisConfig(id=id_count, name=name, host=host, port=port, password=password, database=db,
|
||||||
|
service_type="message_queue", mq_type="redis")
|
||||||
|
configurations.append(config)
|
||||||
|
id_count += 1
|
||||||
|
case "mysql":
|
||||||
|
name: str = 'mysql'
|
||||||
|
host: str = v.get('host')
|
||||||
|
port: int = v.get('port')
|
||||||
|
username = v.get('user')
|
||||||
|
password = v.get('password')
|
||||||
|
config = MySQLConfig(id=id_count, name=name, host=host, port=port, username=username, password=password,
|
||||||
|
service_type="meta_data", meta_type="mysql")
|
||||||
|
configurations.append(config)
|
||||||
|
id_count += 1
|
||||||
|
case "admin":
|
||||||
|
pass
|
||||||
|
case _:
|
||||||
|
logging.warning(f"Unknown configuration key: {k}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return configurations
|
||||||
17
admin/exceptions.py
Normal file
17
admin/exceptions.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
class AdminException(Exception):
|
||||||
|
def __init__(self, message, code=400):
|
||||||
|
super().__init__(message)
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
class UserNotFoundError(AdminException):
|
||||||
|
def __init__(self, username):
|
||||||
|
super().__init__(f"User '{username}' not found", 404)
|
||||||
|
|
||||||
|
class UserAlreadyExistsError(AdminException):
|
||||||
|
def __init__(self, username):
|
||||||
|
super().__init__(f"User '{username}' already exists", 409)
|
||||||
|
|
||||||
|
class CannotDeleteAdminError(AdminException):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("Cannot delete admin account", 403)
|
||||||
0
admin/models.py
Normal file
0
admin/models.py
Normal file
15
admin/responses.py
Normal file
15
admin/responses.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from flask import jsonify
|
||||||
|
|
||||||
|
def success_response(data=None, message="Success", code = 0):
|
||||||
|
return jsonify({
|
||||||
|
"code": code,
|
||||||
|
"message": message,
|
||||||
|
"data": data
|
||||||
|
}), 200
|
||||||
|
|
||||||
|
def error_response(message="Error", code=-1, data=None):
|
||||||
|
return jsonify({
|
||||||
|
"code": code,
|
||||||
|
"message": message,
|
||||||
|
"data": data
|
||||||
|
}), 400
|
||||||
190
admin/routes.py
Normal file
190
admin/routes.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
from flask import Blueprint, request
|
||||||
|
|
||||||
|
from auth import login_verify
|
||||||
|
from responses import success_response, error_response
|
||||||
|
from services import UserMgr, ServiceMgr, UserServiceMgr
|
||||||
|
from exceptions import AdminException
|
||||||
|
|
||||||
|
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/auth', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def auth_admin():
|
||||||
|
try:
|
||||||
|
return success_response(None, "Admin is authorized", 0)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def list_users():
|
||||||
|
try:
|
||||||
|
users = UserMgr.get_all_users()
|
||||||
|
return success_response(users, "Get all users", 0)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users', methods=['POST'])
|
||||||
|
@login_verify
|
||||||
|
def create_user():
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or 'username' not in data or 'password' not in data:
|
||||||
|
return error_response("Username and password are required", 400)
|
||||||
|
|
||||||
|
username = data['username']
|
||||||
|
password = data['password']
|
||||||
|
role = data.get('role', 'user')
|
||||||
|
|
||||||
|
res = UserMgr.create_user(username, password, role)
|
||||||
|
if res["success"]:
|
||||||
|
user_info = res["user_info"]
|
||||||
|
user_info.pop("password") # do not return password
|
||||||
|
return success_response(user_info, "User created successfully")
|
||||||
|
else:
|
||||||
|
return error_response("create user failed")
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
||||||
|
@login_verify
|
||||||
|
def delete_user(username):
|
||||||
|
try:
|
||||||
|
res = UserMgr.delete_user(username)
|
||||||
|
if res["success"]:
|
||||||
|
return success_response(None, res["message"])
|
||||||
|
else:
|
||||||
|
return error_response(res["message"])
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>/password', methods=['PUT'])
|
||||||
|
@login_verify
|
||||||
|
def change_password(username):
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or 'new_password' not in data:
|
||||||
|
return error_response("New password is required", 400)
|
||||||
|
|
||||||
|
new_password = data['new_password']
|
||||||
|
msg = UserMgr.update_user_password(username, new_password)
|
||||||
|
return success_response(None, msg)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
||||||
|
@login_verify
|
||||||
|
def alter_user_activate_status(username):
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or 'activate_status' not in data:
|
||||||
|
return error_response("Activation status is required", 400)
|
||||||
|
activate_status = data['activate_status']
|
||||||
|
msg = UserMgr.update_user_activate_status(username, activate_status)
|
||||||
|
return success_response(None, msg)
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_user_details(username):
|
||||||
|
try:
|
||||||
|
user_details = UserMgr.get_user_details(username)
|
||||||
|
return success_response(user_details)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_user_datasets(username):
|
||||||
|
try:
|
||||||
|
datasets_list = UserServiceMgr.get_user_datasets(username)
|
||||||
|
return success_response(datasets_list)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/users/<username>/agents', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_user_agents(username):
|
||||||
|
try:
|
||||||
|
agents_list = UserServiceMgr.get_user_agents(username)
|
||||||
|
return success_response(agents_list)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/services', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_services():
|
||||||
|
try:
|
||||||
|
services = ServiceMgr.get_all_services()
|
||||||
|
return success_response(services, "Get all services", 0)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/service_types/<service_type>', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_services_by_type(service_type_str):
|
||||||
|
try:
|
||||||
|
services = ServiceMgr.get_services_by_type(service_type_str)
|
||||||
|
return success_response(services)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/services/<service_id>', methods=['GET'])
|
||||||
|
@login_verify
|
||||||
|
def get_service(service_id):
|
||||||
|
try:
|
||||||
|
services = ServiceMgr.get_service_details(service_id)
|
||||||
|
return success_response(services)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/services/<service_id>', methods=['DELETE'])
|
||||||
|
@login_verify
|
||||||
|
def shutdown_service(service_id):
|
||||||
|
try:
|
||||||
|
services = ServiceMgr.shutdown_service(service_id)
|
||||||
|
return success_response(services)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route('/services/<service_id>', methods=['PUT'])
|
||||||
|
@login_verify
|
||||||
|
def restart_service(service_id):
|
||||||
|
try:
|
||||||
|
services = ServiceMgr.restart_service(service_id)
|
||||||
|
return success_response(services)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
175
admin/services.py
Normal file
175
admin/services.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
import re
|
||||||
|
from werkzeug.security import check_password_hash
|
||||||
|
from api.db import ActiveEnum
|
||||||
|
from api.db.services import UserService
|
||||||
|
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
|
||||||
|
from api.db.services.canvas_service import UserCanvasService
|
||||||
|
from api.db.services.user_service import TenantService
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from api.utils.crypt import decrypt
|
||||||
|
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
|
||||||
|
from config import SERVICE_CONFIGS
|
||||||
|
|
||||||
|
class UserMgr:
|
||||||
|
@staticmethod
|
||||||
|
def get_all_users():
|
||||||
|
users = UserService.get_all_users()
|
||||||
|
result = []
|
||||||
|
for user in users:
|
||||||
|
result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date, 'is_active': user.is_active})
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_details(username):
|
||||||
|
# use email to query
|
||||||
|
users = UserService.query_user_by_email(username)
|
||||||
|
result = []
|
||||||
|
for user in users:
|
||||||
|
result.append({
|
||||||
|
'email': user.email,
|
||||||
|
'language': user.language,
|
||||||
|
'last_login_time': user.last_login_time,
|
||||||
|
'is_authenticated': user.is_authenticated,
|
||||||
|
'is_active': user.is_active,
|
||||||
|
'is_anonymous': user.is_anonymous,
|
||||||
|
'login_channel': user.login_channel,
|
||||||
|
'status': user.status,
|
||||||
|
'is_superuser': user.is_superuser,
|
||||||
|
'create_date': user.create_date,
|
||||||
|
'update_date': user.update_date
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_user(username, password, role="user") -> dict:
|
||||||
|
# Validate the email address
|
||||||
|
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
|
||||||
|
raise AdminException(f"Invalid email address: {username}!")
|
||||||
|
# Check if the email address is already used
|
||||||
|
if UserService.query(email=username):
|
||||||
|
raise UserAlreadyExistsError(username)
|
||||||
|
# Construct user info data
|
||||||
|
user_info_dict = {
|
||||||
|
"email": username,
|
||||||
|
"nickname": "", # ask user to edit it manually in settings.
|
||||||
|
"password": decrypt(password),
|
||||||
|
"login_channel": "password",
|
||||||
|
"is_superuser": role == "admin",
|
||||||
|
}
|
||||||
|
return create_new_user(user_info_dict)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_user(username):
|
||||||
|
# use email to delete
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
if len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
usr = user_list[0]
|
||||||
|
return delete_user_data(usr.id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_user_password(username, new_password) -> str:
|
||||||
|
# use email to find user. check exist and unique.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# check new_password different from old.
|
||||||
|
usr = user_list[0]
|
||||||
|
psw = decrypt(new_password)
|
||||||
|
if check_password_hash(usr.password, psw):
|
||||||
|
return "Same password, no need to update!"
|
||||||
|
# update password
|
||||||
|
UserService.update_user_password(usr.id, psw)
|
||||||
|
return "Password updated successfully!"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_user_activate_status(username, activate_status: str):
|
||||||
|
# use email to find user. check exist and unique.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# check activate status different from new
|
||||||
|
usr = user_list[0]
|
||||||
|
# format activate_status before handle
|
||||||
|
_activate_status = activate_status.lower()
|
||||||
|
target_status = {
|
||||||
|
'on': ActiveEnum.ACTIVE.value,
|
||||||
|
'off': ActiveEnum.INACTIVE.value,
|
||||||
|
}.get(_activate_status)
|
||||||
|
if not target_status:
|
||||||
|
raise AdminException(f"Invalid activate_status: {activate_status}")
|
||||||
|
if target_status == usr.is_active:
|
||||||
|
return f"User activate status is already {_activate_status}!"
|
||||||
|
# update is_active
|
||||||
|
UserService.update_user(usr.id, {"is_active": target_status})
|
||||||
|
return f"Turn {_activate_status} user activate status successfully!"
|
||||||
|
|
||||||
|
class UserServiceMgr:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_datasets(username):
|
||||||
|
# use email to find user.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# find tenants
|
||||||
|
usr = user_list[0]
|
||||||
|
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||||
|
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||||
|
# filter permitted kb and owned kb
|
||||||
|
return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_agents(username):
|
||||||
|
# use email to find user.
|
||||||
|
user_list = UserService.query_user_by_email(username)
|
||||||
|
if not user_list:
|
||||||
|
raise UserNotFoundError(username)
|
||||||
|
elif len(user_list) > 1:
|
||||||
|
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||||
|
# find tenants
|
||||||
|
usr = user_list[0]
|
||||||
|
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||||
|
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||||
|
# filter permitted agents and owned agents
|
||||||
|
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||||
|
return [{
|
||||||
|
'title': r['title'],
|
||||||
|
'permission': r['permission'],
|
||||||
|
'canvas_type': r['canvas_type'],
|
||||||
|
'canvas_category': r['canvas_category']
|
||||||
|
} for r in res]
|
||||||
|
|
||||||
|
class ServiceMgr:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_services():
|
||||||
|
result = []
|
||||||
|
configs = SERVICE_CONFIGS.configs
|
||||||
|
for config in configs:
|
||||||
|
result.append(config.to_dict())
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_services_by_type(service_type_str: str):
|
||||||
|
raise AdminException("get_services_by_type: not implemented")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_service_details(service_id: int):
|
||||||
|
raise AdminException("get_service_details: not implemented")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def shutdown_service(service_id: int):
|
||||||
|
raise AdminException("shutdown_service: not implemented")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def restart_service(service_id: int):
|
||||||
|
raise AdminException("restart_service: not implemented")
|
||||||
@ -27,7 +27,7 @@ from agent.component import component_class
|
|||||||
from agent.component.base import ComponentBase
|
from agent.component.base import ComponentBase
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.utils import get_uuid, hash_str2int
|
from api.utils import get_uuid, hash_str2int
|
||||||
from rag.prompts.prompts import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
from rag.utils.redis_conn import REDIS_CONN
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
|
|
||||||
class Graph:
|
class Graph:
|
||||||
@ -490,7 +490,8 @@ class Canvas(Graph):
|
|||||||
|
|
||||||
r = self.retrieval[-1]
|
r = self.retrieval[-1]
|
||||||
for ck in chunks_format({"chunks": chunks}):
|
for ck in chunks_format({"chunks": chunks}):
|
||||||
cid = hash_str2int(ck["id"], 100)
|
cid = hash_str2int(ck["id"], 500)
|
||||||
|
# cid = uuid.uuid5(uuid.NAMESPACE_DNS, ck["id"])
|
||||||
if cid not in r:
|
if cid not in r:
|
||||||
r["chunks"][cid] = ck
|
r["chunks"][cid] = ck
|
||||||
|
|
||||||
|
|||||||
@ -28,9 +28,8 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.mcp_server_service import MCPServerService
|
from api.db.services.mcp_server_service import MCPServerService
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.prompts import message_fit_in
|
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
||||||
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
|
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
||||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
|
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||||
from agent.component.llm import LLMParam, LLM
|
from agent.component.llm import LLMParam, LLM
|
||||||
|
|
||||||
@ -138,7 +137,7 @@ class Agent(LLM, ToolBase):
|
|||||||
res.update(cpn.get_input_form())
|
res.update(cpn.get_input_form())
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if kwargs.get("user_prompt"):
|
if kwargs.get("user_prompt"):
|
||||||
usr_pmt = ""
|
usr_pmt = ""
|
||||||
|
|||||||
@ -244,7 +244,7 @@ class ComponentParamBase(ABC):
|
|||||||
|
|
||||||
if not value_legal:
|
if not value_legal:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
|
"Please check runtime conf, {} = {} does not match user-parameter restriction".format(
|
||||||
variable, value
|
variable, value
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -431,7 +431,7 @@ class ComponentBase(ABC):
|
|||||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
return self.output()
|
return self.output()
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from rag.llm.chat_model import ERROR_PREFIX
|
|||||||
class CategorizeParam(LLMParam):
|
class CategorizeParam(LLMParam):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Define the Categorize component parameters.
|
Define the categorize component parameters.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -80,7 +80,7 @@ Here's description of each category:
|
|||||||
- Prioritize the most specific applicable category
|
- Prioritize the most specific applicable category
|
||||||
- Return only the category name without explanations
|
- Return only the category name without explanations
|
||||||
- Use "Other" only when no other category fits
|
- Use "Other" only when no other category fits
|
||||||
|
|
||||||
""".format(
|
""".format(
|
||||||
"\n - ".join(list(self.category_description.keys())),
|
"\n - ".join(list(self.category_description.keys())),
|
||||||
"\n".join(descriptions)
|
"\n".join(descriptions)
|
||||||
@ -96,7 +96,7 @@ Here's description of each category:
|
|||||||
class Categorize(LLM, ABC):
|
class Categorize(LLM, ABC):
|
||||||
component_name = "Categorize"
|
component_name = "Categorize"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||||
if not msg:
|
if not msg:
|
||||||
@ -112,7 +112,7 @@ class Categorize(LLM, ABC):
|
|||||||
|
|
||||||
user_prompt = """
|
user_prompt = """
|
||||||
---- Real Data ----
|
---- Real Data ----
|
||||||
{} →
|
{} →
|
||||||
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
||||||
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||||
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
||||||
@ -134,4 +134,4 @@ class Categorize(LLM, ABC):
|
|||||||
self.set_output("_next", cpn_ids)
|
self.set_output("_next", cpn_ids)
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class InvokeParam(ComponentParamBase):
|
|||||||
class Invoke(ComponentBase, ABC):
|
class Invoke(ComponentBase, ABC):
|
||||||
component_name = "Invoke"
|
component_name = "Invoke"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
args = {}
|
args = {}
|
||||||
for para in self._param.variables:
|
for para in self._param.variables:
|
||||||
|
|||||||
@ -26,8 +26,7 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.prompts import message_fit_in, citation_prompt
|
from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt
|
||||||
from rag.prompts.prompts import tool_call_summary
|
|
||||||
|
|
||||||
|
|
||||||
class LLMParam(ComponentParamBase):
|
class LLMParam(ComponentParamBase):
|
||||||
@ -82,9 +81,9 @@ class LLMParam(ComponentParamBase):
|
|||||||
|
|
||||||
class LLM(ComponentBase):
|
class LLM(ComponentBase):
|
||||||
component_name = "LLM"
|
component_name = "LLM"
|
||||||
|
|
||||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
def __init__(self, canvas, component_id, param: ComponentParamBase):
|
||||||
super().__init__(canvas, id, param)
|
super().__init__(canvas, component_id, param)
|
||||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
|
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
|
||||||
self._param.llm_id, max_retries=self._param.max_retries,
|
self._param.llm_id, max_retries=self._param.max_retries,
|
||||||
retry_interval=self._param.delay_after_error
|
retry_interval=self._param.delay_after_error
|
||||||
@ -206,7 +205,7 @@ class LLM(ComponentBase):
|
|||||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||||
yield delta(txt)
|
yield delta(txt)
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
def clean_formated_answer(ans: str) -> str:
|
def clean_formated_answer(ans: str) -> str:
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
@ -214,7 +213,7 @@ class LLM(ComponentBase):
|
|||||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||||
|
|
||||||
prompt, msg, _ = self._prepare_prompt_variables()
|
prompt, msg, _ = self._prepare_prompt_variables()
|
||||||
error = ""
|
error: str = ""
|
||||||
|
|
||||||
if self._param.output_structure:
|
if self._param.output_structure:
|
||||||
prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
|
prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
|
||||||
|
|||||||
@ -49,7 +49,7 @@ class MessageParam(ComponentParamBase):
|
|||||||
class Message(ComponentBase):
|
class Message(ComponentBase):
|
||||||
component_name = "Message"
|
component_name = "Message"
|
||||||
|
|
||||||
def get_kwargs(self, script:str, kwargs:dict = {}, delimeter:str=None) -> tuple[str, dict[str, str | list | Any]]:
|
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
|
||||||
for k,v in self.get_input_elements_from_text(script).items():
|
for k,v in self.get_input_elements_from_text(script).items():
|
||||||
if k in kwargs:
|
if k in kwargs:
|
||||||
continue
|
continue
|
||||||
@ -60,8 +60,8 @@ class Message(ComponentBase):
|
|||||||
if isinstance(v, partial):
|
if isinstance(v, partial):
|
||||||
for t in v():
|
for t in v():
|
||||||
ans += t
|
ans += t
|
||||||
elif isinstance(v, list) and delimeter:
|
elif isinstance(v, list) and delimiter:
|
||||||
ans = delimeter.join([str(vv) for vv in v])
|
ans = delimiter.join([str(vv) for vv in v])
|
||||||
elif not isinstance(v, str):
|
elif not isinstance(v, str):
|
||||||
try:
|
try:
|
||||||
ans = json.dumps(v, ensure_ascii=False)
|
ans = json.dumps(v, ensure_ascii=False)
|
||||||
@ -127,7 +127,7 @@ class Message(ComponentBase):
|
|||||||
]
|
]
|
||||||
return any([re.search(p, content) for p in patt])
|
return any([re.search(p, content) for p in patt])
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
rand_cnt = random.choice(self._param.content)
|
rand_cnt = random.choice(self._param.content)
|
||||||
if self._param.stream and not self._is_jinjia2(rand_cnt):
|
if self._param.stream and not self._is_jinjia2(rand_cnt):
|
||||||
|
|||||||
@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
|
|||||||
"type": "line"
|
"type": "line"
|
||||||
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
|
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if self._param.method == "split":
|
if self._param.method == "split":
|
||||||
self._split(kwargs.get("line"))
|
self._split(kwargs.get("line"))
|
||||||
@ -90,7 +90,7 @@ class StringTransform(Message, ABC):
|
|||||||
for k,v in kwargs.items():
|
for k,v in kwargs.items():
|
||||||
if not v:
|
if not v:
|
||||||
v = ""
|
v = ""
|
||||||
script = re.sub(k, v, script)
|
script = re.sub(k, lambda match: v, script)
|
||||||
|
|
||||||
self.set_output("result", script)
|
self.set_output("result", script)
|
||||||
|
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class SwitchParam(ComponentParamBase):
|
|||||||
class Switch(ComponentBase, ABC):
|
class Switch(ComponentBase, ABC):
|
||||||
component_name = "Switch"
|
component_name = "Switch"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
for cond in self._param.conditions:
|
for cond in self._param.conditions:
|
||||||
res = []
|
res = []
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
|
|||||||
class ArXiv(ToolBase, ABC):
|
class ArXiv(ToolBase, ABC):
|
||||||
component_name = "ArXiv"
|
component_name = "ArXiv"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -97,6 +97,6 @@ class ArXiv(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from typing import TypedDict, List, Any
|
|||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
from api.utils import hash_str2int
|
from api.utils import hash_str2int
|
||||||
from rag.llm.chat_model import ToolCallSession
|
from rag.llm.chat_model import ToolCallSession
|
||||||
from rag.prompts.prompts import kb_prompt
|
from rag.prompts.generator import kb_prompt
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
|
|||||||
@ -129,7 +129,7 @@ module.exports = { main };
|
|||||||
class CodeExec(ToolBase, ABC):
|
class CodeExec(ToolBase, ABC):
|
||||||
component_name = "CodeExec"
|
component_name = "CodeExec"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
lang = kwargs.get("lang", self._param.lang)
|
lang = kwargs.get("lang", self._param.lang)
|
||||||
script = kwargs.get("script", self._param.script)
|
script = kwargs.get("script", self._param.script)
|
||||||
@ -157,7 +157,7 @@ class CodeExec(ToolBase, ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
|
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run", code_req, resp.status_code)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
|
|||||||
@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
|
|||||||
class DuckDuckGo(ToolBase, ABC):
|
class DuckDuckGo(ToolBase, ABC):
|
||||||
component_name = "DuckDuckGo"
|
component_name = "DuckDuckGo"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -115,6 +115,6 @@ class DuckDuckGo(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -98,8 +98,8 @@ class EmailParam(ToolParamBase):
|
|||||||
|
|
||||||
class Email(ToolBase, ABC):
|
class Email(ToolBase, ABC):
|
||||||
component_name = "Email"
|
component_name = "Email"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("to_email"):
|
if not kwargs.get("to_email"):
|
||||||
self.set_output("success", False)
|
self.set_output("success", False)
|
||||||
@ -212,4 +212,4 @@ class Email(ToolBase, ABC):
|
|||||||
To: {}
|
To: {}
|
||||||
Subject: {}
|
Subject: {}
|
||||||
Your email is on its way—sit tight!
|
Your email is on its way—sit tight!
|
||||||
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class ExeSQLParam(ToolParamBase):
|
|||||||
self.max_records = 1024
|
self.max_records = 1024
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql'])
|
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
|
||||||
self.check_empty(self.database, "Database name")
|
self.check_empty(self.database, "Database name")
|
||||||
self.check_empty(self.username, "database username")
|
self.check_empty(self.username, "database username")
|
||||||
self.check_empty(self.host, "IP Address")
|
self.check_empty(self.host, "IP Address")
|
||||||
@ -78,7 +78,7 @@ class ExeSQLParam(ToolParamBase):
|
|||||||
class ExeSQL(ToolBase, ABC):
|
class ExeSQL(ToolBase, ABC):
|
||||||
component_name = "ExeSQL"
|
component_name = "ExeSQL"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
|
|
||||||
def convert_decimals(obj):
|
def convert_decimals(obj):
|
||||||
@ -123,6 +123,55 @@ class ExeSQL(ToolBase, ABC):
|
|||||||
r'PWD=' + self._param.password
|
r'PWD=' + self._param.password
|
||||||
)
|
)
|
||||||
db = pyodbc.connect(conn_str)
|
db = pyodbc.connect(conn_str)
|
||||||
|
elif self._param.db_type == 'IBM DB2':
|
||||||
|
import ibm_db
|
||||||
|
conn_str = (
|
||||||
|
f"DATABASE={self._param.database};"
|
||||||
|
f"HOSTNAME={self._param.host};"
|
||||||
|
f"PORT={self._param.port};"
|
||||||
|
f"PROTOCOL=TCPIP;"
|
||||||
|
f"UID={self._param.username};"
|
||||||
|
f"PWD={self._param.password};"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
conn = ibm_db.connect(conn_str, "", "")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Database Connection Failed! \n" + str(e))
|
||||||
|
|
||||||
|
sql_res = []
|
||||||
|
formalized_content = []
|
||||||
|
for single_sql in sqls:
|
||||||
|
single_sql = single_sql.replace("```", "").strip()
|
||||||
|
if not single_sql:
|
||||||
|
continue
|
||||||
|
single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
|
||||||
|
|
||||||
|
stmt = ibm_db.exec_immediate(conn, single_sql)
|
||||||
|
rows = []
|
||||||
|
row = ibm_db.fetch_assoc(stmt)
|
||||||
|
while row and len(rows) < self._param.max_records:
|
||||||
|
rows.append(row)
|
||||||
|
row = ibm_db.fetch_assoc(stmt)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
sql_res.append({"content": "No record in the database!"})
|
||||||
|
continue
|
||||||
|
|
||||||
|
df = pd.DataFrame(rows)
|
||||||
|
for col in df.columns:
|
||||||
|
if pd.api.types.is_datetime64_any_dtype(df[col]):
|
||||||
|
df[col] = df[col].dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
df = df.where(pd.notnull(df), None)
|
||||||
|
|
||||||
|
sql_res.append(convert_decimals(df.to_dict(orient="records")))
|
||||||
|
formalized_content.append(df.to_markdown(index=False, floatfmt=".6f"))
|
||||||
|
|
||||||
|
ibm_db.close(conn)
|
||||||
|
|
||||||
|
self.set_output("json", sql_res)
|
||||||
|
self.set_output("formalized_content", "\n\n".join(formalized_content))
|
||||||
|
return self.output("formalized_content")
|
||||||
try:
|
try:
|
||||||
cursor = db.cursor()
|
cursor = db.cursor()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -150,6 +199,8 @@ class ExeSQL(ToolBase, ABC):
|
|||||||
if pd.api.types.is_datetime64_any_dtype(single_res[col]):
|
if pd.api.types.is_datetime64_any_dtype(single_res[col]):
|
||||||
single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
|
single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
single_res = single_res.where(pd.notnull(single_res), None)
|
||||||
|
|
||||||
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
|
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
|
||||||
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
|
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
|
||||||
|
|
||||||
|
|||||||
@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
|
|||||||
class GitHub(ToolBase, ABC):
|
class GitHub(ToolBase, ABC):
|
||||||
component_name = "GitHub"
|
component_name = "GitHub"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -88,4 +88,4 @@ class GitHub(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
|
|||||||
class Google(ToolBase, ABC):
|
class Google(ToolBase, ABC):
|
||||||
component_name = "Google"
|
component_name = "Google"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("q"):
|
if not kwargs.get("q"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -154,6 +154,6 @@ class Google(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
|
|||||||
class GoogleScholar(ToolBase, ABC):
|
class GoogleScholar(ToolBase, ABC):
|
||||||
component_name = "GoogleScholar"
|
component_name = "GoogleScholar"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -93,4 +93,4 @@ class GoogleScholar(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class PubMedParam(ToolParamBase):
|
|||||||
self.meta:ToolMeta = {
|
self.meta:ToolMeta = {
|
||||||
"name": "pubmed_search",
|
"name": "pubmed_search",
|
||||||
"description": """
|
"description": """
|
||||||
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
||||||
In addition to MEDLINE, PubMed provides access to:
|
In addition to MEDLINE, PubMed provides access to:
|
||||||
- older references from the print version of Index Medicus, back to 1951 and earlier
|
- older references from the print version of Index Medicus, back to 1951 and earlier
|
||||||
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
|
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
|
||||||
@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
|
|||||||
class PubMed(ToolBase, ABC):
|
class PubMed(ToolBase, ABC):
|
||||||
component_name = "PubMed"
|
component_name = "PubMed"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -105,4 +105,4 @@ class PubMed(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -23,8 +23,7 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.prompts import kb_prompt
|
from rag.prompts.generator import cross_languages, kb_prompt
|
||||||
from rag.prompts.prompts import cross_languages
|
|
||||||
|
|
||||||
|
|
||||||
class RetrievalParam(ToolParamBase):
|
class RetrievalParam(ToolParamBase):
|
||||||
@ -75,7 +74,7 @@ class RetrievalParam(ToolParamBase):
|
|||||||
class Retrieval(ToolBase, ABC):
|
class Retrieval(ToolBase, ABC):
|
||||||
component_name = "Retrieval"
|
component_name = "Retrieval"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", self._param.empty_response)
|
self.set_output("formalized_content", self._param.empty_response)
|
||||||
@ -163,13 +162,20 @@ class Retrieval(ToolBase, ABC):
|
|||||||
self.set_output("formalized_content", self._param.empty_response)
|
self.set_output("formalized_content", self._param.empty_response)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Format the chunks for JSON output (similar to how other tools do it)
|
||||||
|
json_output = kbinfos["chunks"].copy()
|
||||||
|
|
||||||
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
|
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
|
||||||
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
|
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
|
||||||
|
|
||||||
|
# Set both formalized content and JSON output
|
||||||
self.set_output("formalized_content", form_cnt)
|
self.set_output("formalized_content", form_cnt)
|
||||||
|
self.set_output("json", json_output)
|
||||||
|
|
||||||
return form_cnt
|
return form_cnt
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
|
|||||||
class SearXNG(ToolBase, ABC):
|
class SearXNG(ToolBase, ABC):
|
||||||
component_name = "SearXNG"
|
component_name = "SearXNG"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
# Gracefully handle try-run without inputs
|
# Gracefully handle try-run without inputs
|
||||||
query = kwargs.get("query")
|
query = kwargs.get("query")
|
||||||
@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
|
|||||||
last_e = ""
|
last_e = ""
|
||||||
for _ in range(self._param.max_retries+1):
|
for _ in range(self._param.max_retries+1):
|
||||||
try:
|
try:
|
||||||
# 构建搜索参数
|
|
||||||
search_params = {
|
search_params = {
|
||||||
'q': query,
|
'q': query,
|
||||||
'format': 'json',
|
'format': 'json',
|
||||||
@ -104,33 +103,29 @@ class SearXNG(ToolBase, ABC):
|
|||||||
'pageno': 1
|
'pageno': 1
|
||||||
}
|
}
|
||||||
|
|
||||||
# 发送搜索请求
|
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{searxng_url}/search",
|
f"{searxng_url}/search",
|
||||||
params=search_params,
|
params=search_params,
|
||||||
timeout=10
|
timeout=10
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
# 验证响应数据
|
|
||||||
if not data or not isinstance(data, dict):
|
if not data or not isinstance(data, dict):
|
||||||
raise ValueError("Invalid response from SearXNG")
|
raise ValueError("Invalid response from SearXNG")
|
||||||
|
|
||||||
results = data.get("results", [])
|
results = data.get("results", [])
|
||||||
if not isinstance(results, list):
|
if not isinstance(results, list):
|
||||||
raise ValueError("Invalid results format from SearXNG")
|
raise ValueError("Invalid results format from SearXNG")
|
||||||
|
|
||||||
# 限制结果数量
|
|
||||||
results = results[:self._param.top_n]
|
results = results[:self._param.top_n]
|
||||||
|
|
||||||
# 处理搜索结果
|
|
||||||
self._retrieve_chunks(results,
|
self._retrieve_chunks(results,
|
||||||
get_title=lambda r: r.get("title", ""),
|
get_title=lambda r: r.get("title", ""),
|
||||||
get_url=lambda r: r.get("url", ""),
|
get_url=lambda r: r.get("url", ""),
|
||||||
get_content=lambda r: r.get("content", ""))
|
get_content=lambda r: r.get("content", ""))
|
||||||
|
|
||||||
self.set_output("json", results)
|
self.set_output("json", results)
|
||||||
return self.output("formalized_content")
|
return self.output("formalized_content")
|
||||||
|
|
||||||
@ -151,6 +146,6 @@ class SearXNG(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Searching with SearXNG for relevant results...
|
Searching with SearXNG for relevant results...
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class TavilySearchParam(ToolParamBase):
|
|||||||
self.meta:ToolMeta = {
|
self.meta:ToolMeta = {
|
||||||
"name": "tavily_search",
|
"name": "tavily_search",
|
||||||
"description": """
|
"description": """
|
||||||
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
||||||
When searching:
|
When searching:
|
||||||
- Start with specific query which should focus on just a single aspect.
|
- Start with specific query which should focus on just a single aspect.
|
||||||
- Number of keywords in query should be less than 5.
|
- Number of keywords in query should be less than 5.
|
||||||
@ -101,7 +101,7 @@ When searching:
|
|||||||
class TavilySearch(ToolBase, ABC):
|
class TavilySearch(ToolBase, ABC):
|
||||||
component_name = "TavilySearch"
|
component_name = "TavilySearch"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -136,7 +136,7 @@ class TavilySearch(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
|
|||||||
class TavilyExtract(ToolBase, ABC):
|
class TavilyExtract(ToolBase, ABC):
|
||||||
component_name = "TavilyExtract"
|
component_name = "TavilyExtract"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
||||||
last_e = None
|
last_e = None
|
||||||
@ -224,4 +224,4 @@ class TavilyExtract(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
||||||
|
|||||||
@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
|
|||||||
class WenCai(ToolBase, ABC):
|
class WenCai(ToolBase, ABC):
|
||||||
component_name = "WenCai"
|
component_name = "WenCai"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("report", "")
|
self.set_output("report", "")
|
||||||
@ -111,4 +111,4 @@ class WenCai(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
|
|||||||
class Wikipedia(ToolBase, ABC):
|
class Wikipedia(ToolBase, ABC):
|
||||||
component_name = "Wikipedia"
|
component_name = "Wikipedia"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("query"):
|
if not kwargs.get("query"):
|
||||||
self.set_output("formalized_content", "")
|
self.set_output("formalized_content", "")
|
||||||
@ -99,6 +99,6 @@ class Wikipedia(ToolBase, ABC):
|
|||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return """
|
return """
|
||||||
Keywords: {}
|
Keywords: {}
|
||||||
Looking for the most relevant articles.
|
Looking for the most relevant articles.
|
||||||
""".format(self.get_input().get("query", "-_-!"))
|
""".format(self.get_input().get("query", "-_-!"))
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
|
|||||||
class YahooFinance(ToolBase, ABC):
|
class YahooFinance(ToolBase, ABC):
|
||||||
component_name = "YahooFinance"
|
component_name = "YahooFinance"
|
||||||
|
|
||||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if not kwargs.get("stock_code"):
|
if not kwargs.get("stock_code"):
|
||||||
self.set_output("report", "")
|
self.set_output("report", "")
|
||||||
@ -111,4 +111,4 @@ class YahooFinance(ToolBase, ABC):
|
|||||||
assert False, self.output()
|
assert False, self.output()
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
||||||
|
|||||||
@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
|||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.db_models import close_connection
|
from api.db.db_models import close_connection
|
||||||
from api.db.services import UserService
|
from api.db.services import UserService
|
||||||
from api.utils import CustomJSONEncoder, commands
|
from api.utils.json import CustomJSONEncoder
|
||||||
|
from api.utils import commands
|
||||||
|
|
||||||
from flask_mail import Mail
|
from flask_mail import Mail
|
||||||
from flask_session import Session
|
from flask_session import Session
|
||||||
|
|||||||
@ -39,7 +39,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
|
|||||||
|
|
||||||
from api.utils.file_utils import filename_type, thumbnail
|
from api.utils.file_utils import filename_type, thumbnail
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.prompts import keyword_extraction
|
from rag.prompts.generator import keyword_extraction
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
|
||||||
from api.db.services.canvas_service import UserCanvasService
|
from api.db.services.canvas_service import UserCanvasService
|
||||||
|
|||||||
@ -100,7 +100,7 @@ def save():
|
|||||||
def get(canvas_id):
|
def get(canvas_id):
|
||||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||||
return get_data_error_result(message="canvas not found.")
|
return get_data_error_result(message="canvas not found.")
|
||||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||||
return get_json_result(data=c)
|
return get_json_result(data=c)
|
||||||
|
|
||||||
|
|
||||||
@ -243,7 +243,7 @@ def reset():
|
|||||||
|
|
||||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||||
def upload(canvas_id):
|
def upload(canvas_id):
|
||||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="canvas not found.")
|
return get_data_error_result(message="canvas not found.")
|
||||||
|
|
||||||
@ -393,6 +393,22 @@ def test_db_connect():
|
|||||||
cursor = db.cursor()
|
cursor = db.cursor()
|
||||||
cursor.execute("SELECT 1")
|
cursor.execute("SELECT 1")
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
elif req["db_type"] == 'IBM DB2':
|
||||||
|
import ibm_db
|
||||||
|
conn_str = (
|
||||||
|
f"DATABASE={req['database']};"
|
||||||
|
f"HOSTNAME={req['host']};"
|
||||||
|
f"PORT={req['port']};"
|
||||||
|
f"PROTOCOL=TCPIP;"
|
||||||
|
f"UID={req['username']};"
|
||||||
|
f"PWD={req['password']};"
|
||||||
|
)
|
||||||
|
logging.info(conn_str)
|
||||||
|
conn = ibm_db.connect(conn_str, "", "")
|
||||||
|
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||||
|
ibm_db.fetch_assoc(stmt)
|
||||||
|
ibm_db.close(conn)
|
||||||
|
return get_json_result(data="Database Connection Successful!")
|
||||||
else:
|
else:
|
||||||
return server_error_response("Unsupported database type.")
|
return server_error_response("Unsupported database type.")
|
||||||
if req["db_type"] != 'mssql':
|
if req["db_type"] != 'mssql':
|
||||||
@ -529,7 +545,7 @@ def sessions(canvas_id):
|
|||||||
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def prompts():
|
def prompts():
|
||||||
from rag.prompts.prompts import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||||
return get_json_result(data={
|
return get_json_result(data={
|
||||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||||
"plan_generation": NEXT_STEP,
|
"plan_generation": NEXT_STEP,
|
||||||
|
|||||||
@ -33,8 +33,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
|
|||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp import rag_tokenizer, search
|
from rag.nlp import rag_tokenizer, search
|
||||||
from rag.prompts import cross_languages, keyword_extraction
|
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
|
||||||
from rag.prompts.prompts import gen_meta_filter
|
|
||||||
from rag.settings import PAGERANK_FLD
|
from rag.settings import PAGERANK_FLD
|
||||||
from rag.utils import rmSpace
|
from rag.utils import rmSpace
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@
|
|||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import traceback
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
from flask_login import current_user, login_required
|
from flask_login import current_user, login_required
|
||||||
@ -29,8 +29,8 @@ from api.db.services.search_service import SearchService
|
|||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||||
from rag.prompts.prompt_template import load_prompt
|
from rag.prompts.template import load_prompt
|
||||||
from rag.prompts.prompts import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||||
@ -226,7 +226,7 @@ def completion():
|
|||||||
if not is_embedded:
|
if not is_embedded:
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
logging.exception(e)
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
|
|||||||
@ -577,7 +577,7 @@ def change_parser():
|
|||||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if req.get("pipeline_id"):
|
if "pipeline_id" in req:
|
||||||
if doc.pipeline_id == req["pipeline_id"]:
|
if doc.pipeline_id == req["pipeline_id"]:
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
|
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
|
||||||
|
|||||||
@ -246,6 +246,8 @@ def rm():
|
|||||||
return get_data_error_result(message="File or Folder not found!")
|
return get_data_error_result(message="File or Folder not found!")
|
||||||
if not file.tenant_id:
|
if not file.tenant_id:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
|
if file.tenant_id != current_user.id:
|
||||||
|
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -292,6 +294,8 @@ def rename():
|
|||||||
e, file = FileService.get_by_id(req["file_id"])
|
e, file = FileService.get_by_id(req["file_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="File not found!")
|
return get_data_error_result(message="File not found!")
|
||||||
|
if file.tenant_id != current_user.id:
|
||||||
|
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
if file.type != FileType.FOLDER.value \
|
if file.type != FileType.FOLDER.value \
|
||||||
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||||
file.name.lower()).suffix:
|
file.name.lower()).suffix:
|
||||||
@ -328,6 +332,8 @@ def get(file_id):
|
|||||||
e, file = FileService.get_by_id(file_id)
|
e, file = FileService.get_by_id(file_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
if file.tenant_id != current_user.id:
|
||||||
|
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
blob = STORAGE_IMPL.get(file.parent_id, file.location)
|
blob = STORAGE_IMPL.get(file.parent_id, file.location)
|
||||||
if not blob:
|
if not blob:
|
||||||
@ -367,6 +373,8 @@ def move():
|
|||||||
return get_data_error_result(message="File or Folder not found!")
|
return get_data_error_result(message="File or Folder not found!")
|
||||||
if not file.tenant_id:
|
if not file.tenant_id:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
|
if file.tenant_id != current_user.id:
|
||||||
|
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
fe, _ = FileService.get_by_id(parent_id)
|
fe, _ = FileService.get_by_id(parent_id)
|
||||||
if not fe:
|
if not fe:
|
||||||
return get_data_error_result(message="Parent Folder not found!")
|
return get_data_error_result(message="Parent Folder not found!")
|
||||||
|
|||||||
@ -40,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_
|
|||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp import rag_tokenizer, search
|
from rag.nlp import rag_tokenizer, search
|
||||||
from rag.prompts import cross_languages, keyword_extraction
|
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||||
from rag.utils import rmSpace
|
from rag.utils import rmSpace
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,11 @@ import re
|
|||||||
|
|
||||||
import flask
|
import flask
|
||||||
from flask import request
|
from flask import request
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.utils.api_utils import server_error_response, token_required
|
from api.utils.api_utils import server_error_response, token_required
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import FileType
|
from api.db import FileType
|
||||||
@ -81,16 +83,16 @@ def upload(tenant_id):
|
|||||||
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
||||||
|
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
# 文件路径处理
|
# Handle file path
|
||||||
full_path = '/' + file_obj.filename
|
full_path = '/' + file_obj.filename
|
||||||
file_obj_names = full_path.split('/')
|
file_obj_names = full_path.split('/')
|
||||||
file_len = len(file_obj_names)
|
file_len = len(file_obj_names)
|
||||||
|
|
||||||
# 获取文件夹路径ID
|
# Get folder path ID
|
||||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||||
len_id_list = len(file_id_list)
|
len_id_list = len(file_id_list)
|
||||||
|
|
||||||
# 创建文件夹结构
|
# Crete file folder
|
||||||
if file_len != len_id_list:
|
if file_len != len_id_list:
|
||||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||||
if not e:
|
if not e:
|
||||||
@ -666,3 +668,71 @@ def move(tenant_id):
|
|||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||||
|
@token_required
|
||||||
|
def convert(tenant_id):
|
||||||
|
req = request.json
|
||||||
|
kb_ids = req["kb_ids"]
|
||||||
|
file_ids = req["file_ids"]
|
||||||
|
file2documents = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
files = FileService.get_by_ids(file_ids)
|
||||||
|
files_set = dict({file.id: file for file in files})
|
||||||
|
for file_id in file_ids:
|
||||||
|
file = files_set[file_id]
|
||||||
|
if not file:
|
||||||
|
return get_json_result(message="File not found!", code=404)
|
||||||
|
file_ids_list = [file_id]
|
||||||
|
if file.type == FileType.FOLDER.value:
|
||||||
|
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||||
|
for id in file_ids_list:
|
||||||
|
informs = File2DocumentService.get_by_file_id(id)
|
||||||
|
# delete
|
||||||
|
for inform in informs:
|
||||||
|
doc_id = inform.document_id
|
||||||
|
e, doc = DocumentService.get_by_id(doc_id)
|
||||||
|
if not e:
|
||||||
|
return get_json_result(message="Document not found!", code=404)
|
||||||
|
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||||
|
if not tenant_id:
|
||||||
|
return get_json_result(message="Tenant not found!", code=404)
|
||||||
|
if not DocumentService.remove_document(doc, tenant_id):
|
||||||
|
return get_json_result(
|
||||||
|
message="Database error (Document removal)!", code=404)
|
||||||
|
File2DocumentService.delete_by_file_id(id)
|
||||||
|
|
||||||
|
# insert
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not e:
|
||||||
|
return get_json_result(
|
||||||
|
message="Can't find this knowledgebase!", code=404)
|
||||||
|
e, file = FileService.get_by_id(id)
|
||||||
|
if not e:
|
||||||
|
return get_json_result(
|
||||||
|
message="Can't find this file!", code=404)
|
||||||
|
|
||||||
|
doc = DocumentService.insert({
|
||||||
|
"id": get_uuid(),
|
||||||
|
"kb_id": kb.id,
|
||||||
|
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||||
|
"parser_config": kb.parser_config,
|
||||||
|
"created_by": tenant_id,
|
||||||
|
"type": file.type,
|
||||||
|
"name": file.name,
|
||||||
|
"suffix": Path(file.name).suffix.lstrip("."),
|
||||||
|
"location": file.location,
|
||||||
|
"size": file.size
|
||||||
|
})
|
||||||
|
file2document = File2DocumentService.insert({
|
||||||
|
"id": get_uuid(),
|
||||||
|
"file_id": id,
|
||||||
|
"document_id": doc.id,
|
||||||
|
})
|
||||||
|
|
||||||
|
file2documents.append(file2document.to_json())
|
||||||
|
return get_json_result(data=file2documents)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
@ -38,9 +38,8 @@ from api.db.services.user_service import UserTenantService
|
|||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
|
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.prompts import chunks_format
|
from rag.prompts.template import load_prompt
|
||||||
from rag.prompts.prompt_template import load_prompt
|
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||||
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
|
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||||
|
|||||||
@ -37,7 +37,8 @@ from timeit import default_timer as timer
|
|||||||
|
|
||||||
from rag.utils.redis_conn import REDIS_CONN
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
from flask import jsonify
|
from flask import jsonify
|
||||||
from api.utils.health import run_health_checks
|
from api.utils.health_utils import run_health_checks
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/version", methods=["GET"]) # noqa: F821
|
@manager.route("/version", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
|
|||||||
@ -34,7 +34,6 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
|
|||||||
from api.utils import (
|
from api.utils import (
|
||||||
current_timestamp,
|
current_timestamp,
|
||||||
datetime_format,
|
datetime_format,
|
||||||
decrypt,
|
|
||||||
download_img,
|
download_img,
|
||||||
get_format_time,
|
get_format_time,
|
||||||
get_uuid,
|
get_uuid,
|
||||||
@ -46,6 +45,7 @@ from api.utils.api_utils import (
|
|||||||
server_error_response,
|
server_error_response,
|
||||||
validate_request,
|
validate_request,
|
||||||
)
|
)
|
||||||
|
from api.utils.crypt import decrypt
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||||
@ -98,7 +98,14 @@ def login():
|
|||||||
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
|
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
|
||||||
|
|
||||||
user = UserService.query_user(email, password)
|
user = UserService.query_user(email, password)
|
||||||
if user:
|
|
||||||
|
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||||
|
return get_json_result(
|
||||||
|
data=False,
|
||||||
|
code=settings.RetCode.FORBIDDEN,
|
||||||
|
message="This account has been disabled, please contact the administrator!",
|
||||||
|
)
|
||||||
|
elif user:
|
||||||
response_data = user.to_json()
|
response_data = user.to_json()
|
||||||
user.access_token = get_uuid()
|
user.access_token = get_uuid()
|
||||||
login_user(user)
|
login_user(user)
|
||||||
@ -227,6 +234,9 @@ def oauth_callback(channel):
|
|||||||
# User exists, try to log in
|
# User exists, try to log in
|
||||||
user = users[0]
|
user = users[0]
|
||||||
user.access_token = get_uuid()
|
user.access_token = get_uuid()
|
||||||
|
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||||
|
return redirect("/?error=user_inactive")
|
||||||
|
|
||||||
login_user(user)
|
login_user(user)
|
||||||
user.save()
|
user.save()
|
||||||
return redirect(f"/?auth={user.get_id()}")
|
return redirect(f"/?auth={user.get_id()}")
|
||||||
@ -317,6 +327,8 @@ def github_callback():
|
|||||||
# User has already registered, try to log in
|
# User has already registered, try to log in
|
||||||
user = users[0]
|
user = users[0]
|
||||||
user.access_token = get_uuid()
|
user.access_token = get_uuid()
|
||||||
|
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||||
|
return redirect("/?error=user_inactive")
|
||||||
login_user(user)
|
login_user(user)
|
||||||
user.save()
|
user.save()
|
||||||
return redirect("/?auth=%s" % user.get_id())
|
return redirect("/?auth=%s" % user.get_id())
|
||||||
@ -418,6 +430,8 @@ def feishu_callback():
|
|||||||
|
|
||||||
# User has already registered, try to log in
|
# User has already registered, try to log in
|
||||||
user = users[0]
|
user = users[0]
|
||||||
|
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||||
|
return redirect("/?error=user_inactive")
|
||||||
user.access_token = get_uuid()
|
user.access_token = get_uuid()
|
||||||
login_user(user)
|
login_user(user)
|
||||||
user.save()
|
user.save()
|
||||||
|
|||||||
2
api/common/README.md
Normal file
2
api/common/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
The python files in this directory are shared between service. They contain common utilities, models, and functions that can be used across various
|
||||||
|
services to ensure consistency and reduce code duplication.
|
||||||
21
api/common/base64.py
Normal file
21
api/common/base64.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
def encode_to_base64(input_string):
|
||||||
|
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||||
|
return base64_encoded.decode('utf-8')
|
||||||
@ -23,6 +23,11 @@ class StatusEnum(Enum):
|
|||||||
INVALID = "0"
|
INVALID = "0"
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveEnum(Enum):
|
||||||
|
ACTIVE = "1"
|
||||||
|
INACTIVE = "0"
|
||||||
|
|
||||||
|
|
||||||
class UserTenantRole(StrEnum):
|
class UserTenantRole(StrEnum):
|
||||||
OWNER = 'owner'
|
OWNER = 'owner'
|
||||||
ADMIN = 'admin'
|
ADMIN = 'admin'
|
||||||
@ -111,7 +116,7 @@ class CanvasCategory(StrEnum):
|
|||||||
Agent = "agent_canvas"
|
Agent = "agent_canvas"
|
||||||
DataFlow = "dataflow_canvas"
|
DataFlow = "dataflow_canvas"
|
||||||
|
|
||||||
VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
|
VALID_CANVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
|
||||||
|
|
||||||
|
|
||||||
class MCPServerType(StrEnum):
|
class MCPServerType(StrEnum):
|
||||||
|
|||||||
@ -26,12 +26,14 @@ from functools import wraps
|
|||||||
|
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||||
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||||
|
|
||||||
from api import settings, utils
|
from api import settings, utils
|
||||||
from api.db import ParserType, SerializedType
|
from api.db import ParserType, SerializedType
|
||||||
|
from api.utils.json import json_dumps, json_loads
|
||||||
|
from api.utils.configs import deserialize_b64, serialize_b64
|
||||||
|
|
||||||
|
|
||||||
def singleton(cls, *args, **kw):
|
def singleton(cls, *args, **kw):
|
||||||
@ -70,12 +72,12 @@ class JSONField(LongTextField):
|
|||||||
def db_value(self, value):
|
def db_value(self, value):
|
||||||
if value is None:
|
if value is None:
|
||||||
value = self.default_value
|
value = self.default_value
|
||||||
return utils.json_dumps(value)
|
return json_dumps(value)
|
||||||
|
|
||||||
def python_value(self, value):
|
def python_value(self, value):
|
||||||
if not value:
|
if not value:
|
||||||
return self.default_value
|
return self.default_value
|
||||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||||
|
|
||||||
|
|
||||||
class ListField(JSONField):
|
class ListField(JSONField):
|
||||||
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
|
|||||||
|
|
||||||
def db_value(self, value):
|
def db_value(self, value):
|
||||||
if self._serialized_type == SerializedType.PICKLE:
|
if self._serialized_type == SerializedType.PICKLE:
|
||||||
return utils.serialize_b64(value, to_str=True)
|
return serialize_b64(value, to_str=True)
|
||||||
elif self._serialized_type == SerializedType.JSON:
|
elif self._serialized_type == SerializedType.JSON:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
return utils.json_dumps(value, with_type=True)
|
return json_dumps(value, with_type=True)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||||
|
|
||||||
def python_value(self, value):
|
def python_value(self, value):
|
||||||
if self._serialized_type == SerializedType.PICKLE:
|
if self._serialized_type == SerializedType.PICKLE:
|
||||||
return utils.deserialize_b64(value)
|
return deserialize_b64(value)
|
||||||
elif self._serialized_type == SerializedType.JSON:
|
elif self._serialized_type == SerializedType.JSON:
|
||||||
if value is None:
|
if value is None:
|
||||||
return {}
|
return {}
|
||||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||||
|
|
||||||
@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def execute_sql(self, sql, params=None, commit=True):
|
def execute_sql(self, sql, params=None, commit=True):
|
||||||
from peewee import OperationalError
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().execute_sql(sql, params, commit)
|
return super().execute_sql(sql, params, commit)
|
||||||
except OperationalError as e:
|
except (OperationalError, InterfaceError) as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
error_codes = [2013, 2006]
|
||||||
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
|
error_messages = ['', 'Lost connection']
|
||||||
|
should_retry = (
|
||||||
|
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||||
|
(str(e) in error_messages) or
|
||||||
|
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_retry and attempt < self.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||||
|
)
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2**attempt))
|
time.sleep(self.retry_delay * (2 ** attempt))
|
||||||
else:
|
else:
|
||||||
logging.error(f"DB execution failure: {e}")
|
logging.error(f"DB execution failure: {e}")
|
||||||
raise
|
raise
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _handle_connection_loss(self):
|
def _handle_connection_loss(self):
|
||||||
self.close_all()
|
# self.close_all()
|
||||||
self.connect()
|
# self.connect()
|
||||||
|
try:
|
||||||
|
self.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self.connect()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to reconnect: {e}")
|
||||||
|
time.sleep(0.1)
|
||||||
|
self.connect()
|
||||||
|
|
||||||
def begin(self):
|
def begin(self):
|
||||||
from peewee import OperationalError
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().begin()
|
return super().begin()
|
||||||
except OperationalError as e:
|
except (OperationalError, InterfaceError) as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
error_codes = [2013, 2006]
|
||||||
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
|
error_messages = ['', 'Lost connection']
|
||||||
|
|
||||||
|
should_retry = (
|
||||||
|
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||||
|
(str(e) in error_messages) or
|
||||||
|
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_retry and attempt < self.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||||
|
)
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2**attempt))
|
time.sleep(self.retry_delay * (2 ** attempt))
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -299,7 +328,16 @@ class BaseDataBase:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
database_config = settings.DATABASE.copy()
|
database_config = settings.DATABASE.copy()
|
||||||
db_name = database_config.pop("name")
|
db_name = database_config.pop("name")
|
||||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
|
||||||
|
pool_config = {
|
||||||
|
'max_retries': 5,
|
||||||
|
'retry_delay': 1,
|
||||||
|
}
|
||||||
|
database_config.update(pool_config)
|
||||||
|
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||||
|
db_name, **database_config
|
||||||
|
)
|
||||||
|
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||||
logging.info("init database on cluster mode successfully")
|
logging.info("init database on cluster mode successfully")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import logging
|
import logging
|
||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@ -32,11 +31,7 @@ from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_l
|
|||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
from api.common.base64 import encode_to_base64
|
||||||
|
|
||||||
def encode_to_base64(input_string):
|
|
||||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
|
||||||
return base64_encoded.decode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
def init_superuser():
|
def init_superuser():
|
||||||
@ -144,8 +139,9 @@ def init_llm_factory():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
break
|
break
|
||||||
|
doc_count = DocumentService.get_all_kb_doc_count()
|
||||||
for kb_id in KnowledgebaseService.get_all_ids():
|
for kb_id in KnowledgebaseService.get_all_ids():
|
||||||
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=DocumentService.get_kb_doc_count(kb_id))
|
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=doc_count.get(kb_id, 0))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
api/db/joint_services/__init__.py
Normal file
0
api/db/joint_services/__init__.py
Normal file
327
api/db/joint_services/user_account_service.py
Normal file
327
api/db/joint_services/user_account_service.py
Normal file
@ -0,0 +1,327 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from api import settings
|
||||||
|
from api.utils.api_utils import group_by
|
||||||
|
from api.db import FileType, UserTenantRole, ActiveEnum
|
||||||
|
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||||
|
from api.db.services.canvas_service import UserCanvasService
|
||||||
|
from api.db.services.conversation_service import ConversationService
|
||||||
|
from api.db.services.dialog_service import DialogService
|
||||||
|
from api.db.services.document_service import DocumentService
|
||||||
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
|
from api.db.services.llm_service import get_init_tenant_llm
|
||||||
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.mcp_server_service import MCPServerService
|
||||||
|
from api.db.services.search_service import SearchService
|
||||||
|
from api.db.services.task_service import TaskService
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
|
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||||
|
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||||
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
from rag.nlp import search
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_user(user_info: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Add a new user, and create tenant, tenant llm, file folder for new user.
|
||||||
|
:param user_info: {
|
||||||
|
"email": <example@example.com>,
|
||||||
|
"nickname": <str, "name">,
|
||||||
|
"password": <decrypted password>,
|
||||||
|
"login_channel": <enum, "password">,
|
||||||
|
"is_superuser": <bool, role == "admin">,
|
||||||
|
}
|
||||||
|
:return: {
|
||||||
|
"success": <bool>,
|
||||||
|
"user_info": <dict>, # if true, return user_info
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# generate user_id and access_token for user
|
||||||
|
user_id = uuid.uuid1().hex
|
||||||
|
user_info['id'] = user_id
|
||||||
|
user_info['access_token'] = uuid.uuid1().hex
|
||||||
|
# construct tenant info
|
||||||
|
tenant = {
|
||||||
|
"id": user_id,
|
||||||
|
"name": user_info["nickname"] + "‘s Kingdom",
|
||||||
|
"llm_id": settings.CHAT_MDL,
|
||||||
|
"embd_id": settings.EMBEDDING_MDL,
|
||||||
|
"asr_id": settings.ASR_MDL,
|
||||||
|
"parser_ids": settings.PARSERS,
|
||||||
|
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||||
|
"rerank_id": settings.RERANK_MDL,
|
||||||
|
}
|
||||||
|
usr_tenant = {
|
||||||
|
"tenant_id": user_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"invited_by": user_id,
|
||||||
|
"role": UserTenantRole.OWNER,
|
||||||
|
}
|
||||||
|
# construct file folder info
|
||||||
|
file_id = uuid.uuid1().hex
|
||||||
|
file = {
|
||||||
|
"id": file_id,
|
||||||
|
"parent_id": file_id,
|
||||||
|
"tenant_id": user_id,
|
||||||
|
"created_by": user_id,
|
||||||
|
"name": "/",
|
||||||
|
"type": FileType.FOLDER.value,
|
||||||
|
"size": 0,
|
||||||
|
"location": "",
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
tenant_llm = get_init_tenant_llm(user_id)
|
||||||
|
|
||||||
|
if not UserService.save(**user_info):
|
||||||
|
return {"success": False}
|
||||||
|
|
||||||
|
TenantService.insert(**tenant)
|
||||||
|
UserTenantService.insert(**usr_tenant)
|
||||||
|
TenantLLMService.insert_many(tenant_llm)
|
||||||
|
FileService.insert(file)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"user_info": user_info,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as create_error:
|
||||||
|
logging.exception(create_error)
|
||||||
|
# rollback
|
||||||
|
try:
|
||||||
|
TenantService.delete_by_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
try:
|
||||||
|
u = UserTenantService.query(tenant_id=user_id)
|
||||||
|
if u:
|
||||||
|
UserTenantService.delete_by_id(u[0].id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
try:
|
||||||
|
TenantLLMService.delete_by_tenant_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
try:
|
||||||
|
FileService.delete_by_id(file["id"])
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
# delete user row finally
|
||||||
|
try:
|
||||||
|
UserService.delete_by_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
# reraise
|
||||||
|
raise create_error
|
||||||
|
|
||||||
|
|
||||||
|
def delete_user_data(user_id: str) -> dict:
|
||||||
|
# use user_id to delete
|
||||||
|
usr = UserService.filter_by_id(user_id)
|
||||||
|
if not usr:
|
||||||
|
return {"success": False, "message": f"{user_id} can't be found."}
|
||||||
|
# check is inactive and not admin
|
||||||
|
if usr.is_active == ActiveEnum.ACTIVE.value:
|
||||||
|
return {"success": False, "message": f"{user_id} is active and can't be deleted."}
|
||||||
|
if usr.is_superuser:
|
||||||
|
return {"success": False, "message": "Can't delete the super user."}
|
||||||
|
# tenant info
|
||||||
|
tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
|
||||||
|
owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]
|
||||||
|
|
||||||
|
done_msg = ''
|
||||||
|
try:
|
||||||
|
# step1. delete owned tenant info
|
||||||
|
if owned_tenant:
|
||||||
|
done_msg += "Start to delete owned tenant.\n"
|
||||||
|
tenant_id = owned_tenant[0]["tenant_id"]
|
||||||
|
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
|
||||||
|
# step1.1 delete knowledgebase related file and info
|
||||||
|
if kb_ids:
|
||||||
|
# step1.1.1 delete files in storage, remove bucket
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
if STORAGE_IMPL.bucket_exists(kb_id):
|
||||||
|
STORAGE_IMPL.remove_bucket(kb_id)
|
||||||
|
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
|
||||||
|
# step1.1.2 delete file and document info in db
|
||||||
|
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
|
||||||
|
if doc_ids:
|
||||||
|
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
|
||||||
|
done_msg += f"- Deleted {doc_delete_res} document records.\n"
|
||||||
|
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
|
||||||
|
done_msg += f"- Deleted {task_delete_res} task records.\n"
|
||||||
|
file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
|
||||||
|
if file_ids:
|
||||||
|
file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
|
||||||
|
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||||
|
if doc_ids or file_ids:
|
||||||
|
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||||
|
[i["id"] for i in doc_ids],
|
||||||
|
[f["id"] for f in file_ids]
|
||||||
|
)
|
||||||
|
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||||
|
# step1.1.3 delete chunk in es
|
||||||
|
r = settings.docStoreConn.delete({"kb_id": kb_ids},
|
||||||
|
search.index_name(tenant_id), kb_ids)
|
||||||
|
done_msg += f"- Deleted {r} chunk records.\n"
|
||||||
|
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
|
||||||
|
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
|
||||||
|
# step1.1.4 delete agents
|
||||||
|
agent_delete_res = delete_user_agents(usr.id)
|
||||||
|
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
|
||||||
|
# step1.1.5 delete dialogs
|
||||||
|
dialog_delete_res = delete_user_dialogs(usr.id)
|
||||||
|
done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
|
||||||
|
# step1.1.6 delete mcp server
|
||||||
|
mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
|
||||||
|
done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
|
||||||
|
# step1.1.7 delete search
|
||||||
|
search_delete_res = SearchService.delete_by_tenant_id(usr.id)
|
||||||
|
done_msg += f"- Deleted {search_delete_res} search records.\n"
|
||||||
|
# step1.2 delete tenant_llm and tenant_langfuse
|
||||||
|
llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
|
||||||
|
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
|
||||||
|
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
|
||||||
|
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
|
||||||
|
# step1.3 delete own tenant
|
||||||
|
tenant_delete_res = TenantService.delete_by_id(tenant_id)
|
||||||
|
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
|
||||||
|
# step2 delete user-tenant relation
|
||||||
|
if tenants:
|
||||||
|
# step2.1 delete docs and files in joined team
|
||||||
|
joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
|
||||||
|
if joined_tenants:
|
||||||
|
done_msg += "Start to delete data in joined tenants.\n"
|
||||||
|
created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
|
||||||
|
if created_documents:
|
||||||
|
# step2.1.1 delete files
|
||||||
|
doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
|
||||||
|
created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
|
||||||
|
if created_files:
|
||||||
|
# step2.1.1.1 delete file in storage
|
||||||
|
for f in created_files:
|
||||||
|
STORAGE_IMPL.rm(f.parent_id, f.location)
|
||||||
|
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
|
||||||
|
# step2.1.1.2 delete file record
|
||||||
|
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
|
||||||
|
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||||
|
# step2.1.2 delete document-file relation record
|
||||||
|
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||||
|
[d['id'] for d in created_documents],
|
||||||
|
[f.id for f in created_files]
|
||||||
|
)
|
||||||
|
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||||
|
# step2.1.3 delete chunks
|
||||||
|
doc_groups = group_by(created_documents, "tenant_id")
|
||||||
|
kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
|
||||||
|
# chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
|
||||||
|
chunk_delete_res = 0
|
||||||
|
kb_doc_info = {}
|
||||||
|
for _tenant_id, kb_doc in kb_grouped_doc.items():
|
||||||
|
for _kb_id, docs in kb_doc.items():
|
||||||
|
chunk_delete_res += settings.docStoreConn.delete(
|
||||||
|
{"doc_id": [d["id"] for d in docs]},
|
||||||
|
search.index_name(_tenant_id), _kb_id
|
||||||
|
)
|
||||||
|
# record doc info
|
||||||
|
if _kb_id in kb_doc_info.keys():
|
||||||
|
kb_doc_info[_kb_id]['doc_num'] += 1
|
||||||
|
kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
|
||||||
|
kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
|
||||||
|
else:
|
||||||
|
kb_doc_info[_kb_id] = {
|
||||||
|
'doc_num': 1,
|
||||||
|
'token_num': sum([d["token_num"] for d in docs]),
|
||||||
|
'chunk_num': sum([d["chunk_num"] for d in docs])
|
||||||
|
}
|
||||||
|
done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
|
||||||
|
# step2.1.4 delete tasks
|
||||||
|
task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
|
||||||
|
done_msg += f"- Deleted {task_delete_res} tasks.\n"
|
||||||
|
# step2.1.5 delete document record
|
||||||
|
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
|
||||||
|
done_msg += f"- Deleted {doc_delete_res} documents.\n"
|
||||||
|
# step2.1.6 update knowledge base doc&chunk&token cnt
|
||||||
|
for kb_id, doc_num in kb_doc_info.items():
|
||||||
|
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
|
||||||
|
|
||||||
|
# step2.2 delete relation
|
||||||
|
user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
|
||||||
|
done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
|
||||||
|
# step3 finally delete user
|
||||||
|
user_delete_res = UserService.delete_by_id(usr.id)
|
||||||
|
done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"
|
||||||
|
|
||||||
|
return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
|
||||||
|
|
||||||
|
|
||||||
|
def delete_user_agents(user_id: str) -> dict:
|
||||||
|
"""
|
||||||
|
use user_id to delete
|
||||||
|
:return: {
|
||||||
|
"agents_deleted_count": 1,
|
||||||
|
"version_deleted_count": 2
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
agents_deleted_count, agents_version_deleted_count = 0, 0
|
||||||
|
user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
|
||||||
|
if user_agents:
|
||||||
|
agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
|
||||||
|
agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
|
||||||
|
agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
|
||||||
|
return {
|
||||||
|
"agents_deleted_count": agents_deleted_count,
|
||||||
|
"version_deleted_count": agents_version_deleted_count
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def delete_user_dialogs(user_id: str) -> dict:
|
||||||
|
"""
|
||||||
|
use user_id to delete
|
||||||
|
:return: {
|
||||||
|
"dialogs_deleted_count": 1,
|
||||||
|
"conversations_deleted_count": 1,
|
||||||
|
"api_token_deleted_count": 2,
|
||||||
|
"api4conversation_deleted_count": 2
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
|
||||||
|
user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
|
||||||
|
if user_dialogs:
|
||||||
|
# delete conversation
|
||||||
|
conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||||
|
conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
|
||||||
|
# delete api token
|
||||||
|
api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
|
||||||
|
# delete api for conversation
|
||||||
|
api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||||
|
# delete dialog at last
|
||||||
|
dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
|
||||||
|
return {
|
||||||
|
"dialogs_deleted_count": dialog_deleted_count,
|
||||||
|
"conversations_deleted_count": conversations_deleted_count,
|
||||||
|
"api_token_deleted_count": api_token_deleted_count,
|
||||||
|
"api4conversation_deleted_count": api4conversation_deleted_count
|
||||||
|
}
|
||||||
@ -19,7 +19,7 @@ from pathlib import PurePath
|
|||||||
from .user_service import UserService as UserService
|
from .user_service import UserService as UserService
|
||||||
|
|
||||||
|
|
||||||
def split_name_counter(filename: str) -> tuple[str, int | None]:
|
def _split_name_counter(filename: str) -> tuple[str, int | None]:
|
||||||
"""
|
"""
|
||||||
Splits a filename into main part and counter (if present in parentheses).
|
Splits a filename into main part and counter (if present in parentheses).
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ def duplicate_name(query_func, **kwargs) -> str:
|
|||||||
stem = path.stem
|
stem = path.stem
|
||||||
suffix = path.suffix
|
suffix = path.suffix
|
||||||
|
|
||||||
main_part, counter = split_name_counter(stem)
|
main_part, counter = _split_name_counter(stem)
|
||||||
counter = counter + 1 if counter else 1
|
counter = counter + 1 if counter else 1
|
||||||
|
|
||||||
new_name = f"{main_part}({counter}){suffix}"
|
new_name = f"{main_part}({counter}){suffix}"
|
||||||
|
|||||||
@ -35,6 +35,11 @@ class APITokenService(CommonService):
|
|||||||
cls.model.token == token
|
cls.model.token == token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_by_tenant_id(cls, tenant_id):
|
||||||
|
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||||
|
|
||||||
|
|
||||||
class API4ConversationService(CommonService):
|
class API4ConversationService(CommonService):
|
||||||
model = API4Conversation
|
model = API4Conversation
|
||||||
@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
|
|||||||
cls.model.create_date <= to_date,
|
cls.model.create_date <= to_date,
|
||||||
cls.model.source == source
|
cls.model.source == source
|
||||||
).group_by(cls.model.create_date.truncate("day")).dicts()
|
).group_by(cls.model.create_date.truncate("day")).dicts()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_by_dialog_ids(cls, dialog_ids):
|
||||||
|
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
from api.db import CanvasCategory
|
from api.db import CanvasCategory, TenantPermission
|
||||||
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
|
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
|
||||||
from api.db.services.api_service import API4ConversationService
|
from api.db.services.api_service import API4ConversationService
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
@ -63,7 +63,38 @@ class UserCanvasService(CommonService):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_tenant_id(cls, pid):
|
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||||
|
# will get all permitted agents, be cautious
|
||||||
|
fields = [
|
||||||
|
cls.model.id,
|
||||||
|
cls.model.title,
|
||||||
|
cls.model.permission,
|
||||||
|
cls.model.canvas_type,
|
||||||
|
cls.model.canvas_category
|
||||||
|
]
|
||||||
|
# find team agents and owned agents
|
||||||
|
agents = cls.model.select(*fields).where(
|
||||||
|
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||||
|
cls.model.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# sort by create_time, asc
|
||||||
|
agents.order_by(cls.model.create_time.asc())
|
||||||
|
# maybe cause slow query by deep paginate, optimize later
|
||||||
|
offset, limit = 0, 50
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
ag_batch = agents.offset(offset).limit(limit)
|
||||||
|
_temp = list(ag_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_by_canvas_id(cls, pid):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
@ -138,7 +169,7 @@ class UserCanvasService(CommonService):
|
|||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def accessible(cls, canvas_id, tenant_id):
|
def accessible(cls, canvas_id, tenant_id):
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||||
if not e:
|
if not e:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@ -14,12 +14,24 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||||
import peewee
|
import peewee
|
||||||
|
from peewee import InterfaceError, OperationalError
|
||||||
|
|
||||||
from api.db.db_models import DB
|
from api.db.db_models import DB
|
||||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||||
|
|
||||||
|
def retry_db_operation(func):
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||||
|
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||||
|
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
class CommonService:
|
class CommonService:
|
||||||
"""Base service class that provides common database operations.
|
"""Base service class that provides common database operations.
|
||||||
@ -202,6 +214,7 @@ class CommonService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
|
@retry_db_operation
|
||||||
def update_by_id(cls, pid, data):
|
def update_by_id(cls, pid, data):
|
||||||
# Update a single record by ID
|
# Update a single record by ID
|
||||||
# Args:
|
# Args:
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from api.db.services.dialog_service import DialogService, chat
|
|||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from rag.prompts import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
|
|
||||||
|
|
||||||
class ConversationService(CommonService):
|
class ConversationService(CommonService):
|
||||||
@ -48,6 +48,21 @@ class ConversationService(CommonService):
|
|||||||
|
|
||||||
return list(sessions.dicts())
|
return list(sessions.dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
|
||||||
|
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
|
||||||
|
sessions.order_by(cls.model.create_time.asc())
|
||||||
|
offset, limit = 0, 100
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
s_batch = sessions.offset(offset).limit(limit)
|
||||||
|
_temp = list(s_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
def structure_answer(conv, ans, message_id, session_id):
|
def structure_answer(conv, ans, message_id, session_id):
|
||||||
reference = ans["reference"]
|
reference = ans["reference"]
|
||||||
|
|||||||
@ -39,8 +39,8 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
|
|||||||
from rag.app.resume import forbidden_select_fields4resume
|
from rag.app.resume import forbidden_select_fields4resume
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
|
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||||
from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||||
from rag.utils import num_tokens_from_string, rmSpace
|
from rag.utils import num_tokens_from_string, rmSpace
|
||||||
from rag.utils.tavily_conn import Tavily
|
from rag.utils.tavily_conn import Tavily
|
||||||
|
|
||||||
@ -159,6 +159,22 @@ class DialogService(CommonService):
|
|||||||
|
|
||||||
return list(dialogs.dicts()), count
|
return list(dialogs.dicts()), count
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_dialogs_by_tenant_id(cls, tenant_id):
|
||||||
|
fields = [cls.model.id]
|
||||||
|
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||||
|
dialogs.order_by(cls.model.create_time.asc())
|
||||||
|
offset, limit = 0, 100
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
d_batch = dialogs.offset(offset).limit(limit)
|
||||||
|
_temp = list(d_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
def chat_solo(dialog, messages, stream=True):
|
def chat_solo(dialog, messages, stream=True):
|
||||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||||
@ -176,7 +192,7 @@ def chat_solo(dialog, messages, stream=True):
|
|||||||
delta_ans = ""
|
delta_ans = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||||
answer = ans
|
answer = ans
|
||||||
delta_ans = ans[len(last_ans) :]
|
delta_ans = ans[len(last_ans):]
|
||||||
if num_tokens_from_string(delta_ans) < 16:
|
if num_tokens_from_string(delta_ans) < 16:
|
||||||
continue
|
continue
|
||||||
last_ans = answer
|
last_ans = answer
|
||||||
@ -261,13 +277,13 @@ def convert_conditions(metadata_condition):
|
|||||||
"not is": "≠"
|
"not is": "≠"
|
||||||
}
|
}
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||||
"key": cond["name"],
|
"key": cond["name"],
|
||||||
"value": cond["value"]
|
"value": cond["value"]
|
||||||
}
|
}
|
||||||
for cond in metadata_condition.get("conditions", [])
|
for cond in metadata_condition.get("conditions", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def meta_filter(metas: dict, filters: list[dict]):
|
def meta_filter(metas: dict, filters: list[dict]):
|
||||||
@ -284,19 +300,19 @@ def meta_filter(metas: dict, filters: list[dict]):
|
|||||||
value = str(value)
|
value = str(value)
|
||||||
|
|
||||||
for conds in [
|
for conds in [
|
||||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||||
(operator == "empty", not input),
|
(operator == "empty", not input),
|
||||||
(operator == "not empty", input),
|
(operator == "not empty", input),
|
||||||
(operator == "=", input == value),
|
(operator == "=", input == value),
|
||||||
(operator == "≠", input != value),
|
(operator == "≠", input != value),
|
||||||
(operator == ">", input > value),
|
(operator == ">", input > value),
|
||||||
(operator == "<", input < value),
|
(operator == "<", input < value),
|
||||||
(operator == "≥", input >= value),
|
(operator == "≥", input >= value),
|
||||||
(operator == "≤", input <= value),
|
(operator == "≤", input <= value),
|
||||||
]:
|
]:
|
||||||
try:
|
try:
|
||||||
if all(conds):
|
if all(conds):
|
||||||
ids.extend(docids)
|
ids.extend(docids)
|
||||||
@ -456,7 +472,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||||
if prompt_config.get("use_kg"):
|
if prompt_config.get("use_kg"):
|
||||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||||
|
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
kbinfos["chunks"].insert(0, ck)
|
kbinfos["chunks"].insert(0, ck)
|
||||||
|
|
||||||
@ -467,7 +484,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
retrieval_ts = timer()
|
retrieval_ts = timer()
|
||||||
if not knowledges and prompt_config.get("empty_response"):
|
if not knowledges and prompt_config.get("empty_response"):
|
||||||
empty_res = prompt_config["empty_response"]
|
empty_res = prompt_config["empty_response"]
|
||||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
|
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||||
|
"audio_binary": tts(tts_mdl, empty_res)}
|
||||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||||
|
|
||||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||||
@ -565,7 +583,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
|
|
||||||
if langfuse_tracer:
|
if langfuse_tracer:
|
||||||
langfuse_generation = langfuse_tracer.start_generation(
|
langfuse_generation = langfuse_tracer.start_generation(
|
||||||
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
|
||||||
|
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
@ -575,12 +594,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
if thought:
|
if thought:
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
answer = ans
|
answer = ans
|
||||||
delta_ans = ans[len(last_ans) :]
|
delta_ans = ans[len(last_ans):]
|
||||||
if num_tokens_from_string(delta_ans) < 16:
|
if num_tokens_from_string(delta_ans) < 16:
|
||||||
continue
|
continue
|
||||||
last_ans = answer
|
last_ans = answer
|
||||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
delta_ans = answer[len(last_ans) :]
|
delta_ans = answer[len(last_ans):]
|
||||||
if delta_ans:
|
if delta_ans:
|
||||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
yield decorate_answer(thought + answer)
|
yield decorate_answer(thought + answer)
|
||||||
@ -676,7 +695,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||||||
|
|
||||||
# compose Markdown table
|
# compose Markdown table
|
||||||
columns = (
|
columns = (
|
||||||
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
"|" + "|".join(
|
||||||
|
[re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
|
||||||
|
"|Source|" if docid_idx and docid_idx else "|")
|
||||||
)
|
)
|
||||||
|
|
||||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||||
@ -753,7 +774,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
|||||||
doc_ids = None
|
doc_ids = None
|
||||||
|
|
||||||
kbinfos = retriever.retrieval(
|
kbinfos = retriever.retrieval(
|
||||||
question = question,
|
question=question,
|
||||||
embd_mdl=embd_mdl,
|
embd_mdl=embd_mdl,
|
||||||
tenant_ids=tenant_ids,
|
tenant_ids=tenant_ids,
|
||||||
kb_ids=kb_ids,
|
kb_ids=kb_ids,
|
||||||
@ -775,7 +796,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
|||||||
|
|
||||||
def decorate_answer(answer):
|
def decorate_answer(answer):
|
||||||
nonlocal knowledges, kbinfos, sys_prompt
|
nonlocal knowledges, kbinfos, sys_prompt
|
||||||
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
|
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
|
||||||
|
embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||||
if not recall_docs:
|
if not recall_docs:
|
||||||
|
|||||||
@ -243,6 +243,46 @@ class DocumentService(CommonService):
|
|||||||
|
|
||||||
return int(query.scalar()) or 0
|
return int(query.scalar()) or 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
|
||||||
|
fields = [cls.model.id]
|
||||||
|
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
|
||||||
|
docs.order_by(cls.model.create_time.asc())
|
||||||
|
# maybe cause slow query by deep paginate, optimize later
|
||||||
|
offset, limit = 0, 100
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
doc_batch = docs.offset(offset).limit(limit)
|
||||||
|
_temp = list(doc_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_docs_by_creator_id(cls, creator_id):
|
||||||
|
fields = [
|
||||||
|
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
|
||||||
|
]
|
||||||
|
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||||
|
cls.model.created_by == creator_id
|
||||||
|
)
|
||||||
|
docs.order_by(cls.model.create_time.asc())
|
||||||
|
# maybe cause slow query by deep paginate, optimize later
|
||||||
|
offset, limit = 0, 100
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
doc_batch = docs.offset(offset).limit(limit)
|
||||||
|
_temp = list(doc_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def insert(cls, doc):
|
def insert(cls, doc):
|
||||||
@ -517,9 +557,6 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_doc_id_by_doc_name(cls, doc_name):
|
def get_doc_id_by_doc_name(cls, doc_name):
|
||||||
"""
|
|
||||||
highly rely on the strict deduplication guarantee from Document
|
|
||||||
"""
|
|
||||||
fields = [cls.model.id]
|
fields = [cls.model.id]
|
||||||
doc_id = cls.model.select(*fields) \
|
doc_id = cls.model.select(*fields) \
|
||||||
.where(cls.model.name == doc_name)
|
.where(cls.model.name == doc_name)
|
||||||
@ -681,8 +718,16 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_doc_count(cls, kb_id):
|
def get_kb_doc_count(cls, kb_id):
|
||||||
return len(cls.model.select(cls.model.id).where(
|
return cls.model.select().where(cls.model.kb_id == kb_id).count()
|
||||||
cls.model.kb_id == kb_id).dicts())
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_kb_doc_count(cls):
|
||||||
|
result = {}
|
||||||
|
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
|
||||||
|
for row in rows:
|
||||||
|
result[row.kb_id] = row.count
|
||||||
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
|
|||||||
@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
|
|||||||
objs = cls.model.select().where(cls.model.document_id == document_id)
|
objs = cls.model.select().where(cls.model.document_id == document_id)
|
||||||
return objs
|
return objs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_by_document_ids(cls, document_ids):
|
||||||
|
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
|
||||||
|
return list(objs.dicts())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def insert(cls, obj):
|
def insert(cls, obj):
|
||||||
@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
|
|||||||
def delete_by_file_id(cls, file_id):
|
def delete_by_file_id(cls, file_id):
|
||||||
return cls.model.delete().where(cls.model.file_id == file_id).execute()
|
return cls.model.delete().where(cls.model.file_id == file_id).execute()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
|
||||||
|
if not document_ids:
|
||||||
|
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
|
||||||
|
elif not file_ids:
|
||||||
|
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
|
||||||
|
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def delete_by_document_id(cls, doc_id):
|
def delete_by_document_id(cls, doc_id):
|
||||||
|
|||||||
@ -161,6 +161,23 @@ class FileService(CommonService):
|
|||||||
result_ids.append(folder_id)
|
result_ids.append(folder_id)
|
||||||
return result_ids
|
return result_ids
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_file_ids_by_tenant_id(cls, tenant_id):
|
||||||
|
fields = [cls.model.id]
|
||||||
|
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||||
|
files.order_by(cls.model.create_time.asc())
|
||||||
|
offset, limit = 0, 100
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
file_batch = files.offset(offset).limit(limit)
|
||||||
|
_temp = list(file_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def create_folder(cls, file, parent_id, name, count):
|
def create_folder(cls, file, parent_id, name, count):
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from datetime import datetime
|
|||||||
from peewee import fn, JOIN
|
from peewee import fn, JOIN
|
||||||
|
|
||||||
from api.db import StatusEnum, TenantPermission
|
from api.db import StatusEnum, TenantPermission
|
||||||
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant, UserCanvas
|
from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.utils import current_timestamp, datetime_format
|
from api.utils import current_timestamp, datetime_format
|
||||||
|
|
||||||
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
|
|||||||
|
|
||||||
return list(kbs.dicts()), count
|
return list(kbs.dicts()), count
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
|
||||||
|
# will get all permitted kb, be cautious.
|
||||||
|
fields = [
|
||||||
|
cls.model.name,
|
||||||
|
cls.model.language,
|
||||||
|
cls.model.permission,
|
||||||
|
cls.model.doc_num,
|
||||||
|
cls.model.token_num,
|
||||||
|
cls.model.chunk_num,
|
||||||
|
cls.model.status,
|
||||||
|
cls.model.create_date,
|
||||||
|
cls.model.update_date
|
||||||
|
]
|
||||||
|
# find team kb and owned kb
|
||||||
|
kbs = cls.model.select(*fields).where(
|
||||||
|
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
|
||||||
|
cls.model.tenant_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# sort by create_time asc
|
||||||
|
kbs.order_by(cls.model.create_time.asc())
|
||||||
|
# maybe cause slow query by deep paginate, optimize later.
|
||||||
|
offset, limit = 0, 50
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
kb_batch = kbs.offset(offset).limit(limit)
|
||||||
|
_temp = list(kb_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_ids(cls, tenant_id):
|
def get_kb_ids(cls, tenant_id):
|
||||||
@ -226,7 +261,7 @@ class KnowledgebaseService(CommonService):
|
|||||||
cls.model.chunk_num,
|
cls.model.chunk_num,
|
||||||
cls.model.parser_id,
|
cls.model.parser_id,
|
||||||
cls.model.pipeline_id,
|
cls.model.pipeline_id,
|
||||||
UserCanvas.title,
|
UserCanvas.title.alias("pipeline_name"),
|
||||||
UserCanvas.avatar.alias("pipeline_avatar"),
|
UserCanvas.avatar.alias("pipeline_avatar"),
|
||||||
cls.model.parser_config,
|
cls.model.parser_config,
|
||||||
cls.model.pagerank,
|
cls.model.pagerank,
|
||||||
@ -240,16 +275,14 @@ class KnowledgebaseService(CommonService):
|
|||||||
cls.model.update_time
|
cls.model.update_time
|
||||||
]
|
]
|
||||||
kbs = cls.model.select(*fields)\
|
kbs = cls.model.select(*fields)\
|
||||||
.join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)))\
|
|
||||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||||
.where(
|
.where(
|
||||||
(cls.model.id == kb_id),
|
(cls.model.id == kb_id),
|
||||||
(cls.model.status == StatusEnum.VALID.value)
|
(cls.model.status == StatusEnum.VALID.value)
|
||||||
)
|
).dicts()
|
||||||
if not kbs:
|
if not kbs:
|
||||||
return
|
return
|
||||||
d = kbs[0].to_dict()
|
return kbs[0]
|
||||||
return d
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
@ -447,3 +480,17 @@ class KnowledgebaseService(CommonService):
|
|||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
|
||||||
|
kb_row = cls.model.get_by_id(kb_id)
|
||||||
|
if not kb_row:
|
||||||
|
raise RuntimeError(f"kb_id {kb_id} does not exist")
|
||||||
|
update_dict = {
|
||||||
|
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
|
||||||
|
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
|
||||||
|
'token_num': kb_row.token_num - doc_num_info['token_num'],
|
||||||
|
'update_time': current_timestamp(),
|
||||||
|
'update_date': datetime_format(datetime.now())
|
||||||
|
}
|
||||||
|
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||||
|
|||||||
@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
|
|||||||
except peewee.DoesNotExist:
|
except peewee.DoesNotExist:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_ty_tenant_id(cls, tenant_id):
|
||||||
|
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_by_tenant(cls, tenant_id, langfuse_keys):
|
def update_by_tenant(cls, tenant_id, langfuse_keys):
|
||||||
langfuse_keys["update_time"] = current_timestamp()
|
langfuse_keys["update_time"] = current_timestamp()
|
||||||
|
|||||||
@ -84,3 +84,8 @@ class MCPServerService(CommonService):
|
|||||||
return bool(mcp_server), mcp_server
|
return bool(mcp_server), mcp_server
|
||||||
except Exception:
|
except Exception:
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_by_tenant_id(cls, tenant_id: str):
|
||||||
|
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||||
|
|||||||
@ -110,3 +110,8 @@ class SearchService(CommonService):
|
|||||||
query = query.paginate(page_number, items_per_page)
|
query = query.paginate(page_number, items_per_page)
|
||||||
|
|
||||||
return list(query.dicts()), count
|
return list(query.dicts()), count
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_by_tenant_id(cls, tenant_id):
|
||||||
|
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||||
|
|||||||
@ -316,6 +316,12 @@ class TaskService(CommonService):
|
|||||||
process_duration = (datetime.now() - task.begin_at).total_seconds()
|
process_duration = (datetime.now() - task.begin_at).total_seconds()
|
||||||
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
|
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_by_doc_ids(cls, doc_ids):
|
||||||
|
"""Delete task associated with a document."""
|
||||||
|
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
|
||||||
|
|
||||||
|
|
||||||
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||||
"""Create and queue document processing tasks.
|
"""Create and queue document processing tasks.
|
||||||
|
|||||||
@ -209,6 +209,11 @@ class TenantLLMService(CommonService):
|
|||||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||||
return list(objs)
|
return list(objs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def delete_by_tenant_id(cls, tenant_id):
|
||||||
|
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||||
from api.db.services.llm_service import LLMService
|
from api.db.services.llm_service import LLMService
|
||||||
|
|||||||
@ -24,7 +24,24 @@ class UserCanvasVersionService(CommonService):
|
|||||||
return None
|
return None
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
|
||||||
|
fields = [cls.model.id]
|
||||||
|
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
|
||||||
|
versions.order_by(cls.model.create_time.asc())
|
||||||
|
offset, limit = 0, 100
|
||||||
|
res = []
|
||||||
|
while True:
|
||||||
|
version_batch = versions.offset(offset).limit(limit)
|
||||||
|
_temp = list(version_batch.dicts())
|
||||||
|
if not _temp:
|
||||||
|
break
|
||||||
|
res.extend(_temp)
|
||||||
|
offset += limit
|
||||||
|
return res
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def delete_all_versions(cls, user_canvas_id):
|
def delete_all_versions(cls, user_canvas_id):
|
||||||
|
|||||||
@ -45,22 +45,22 @@ class UserService(CommonService):
|
|||||||
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
||||||
if 'access_token' in kwargs:
|
if 'access_token' in kwargs:
|
||||||
access_token = kwargs['access_token']
|
access_token = kwargs['access_token']
|
||||||
|
|
||||||
# Reject empty, None, or whitespace-only access tokens
|
# Reject empty, None, or whitespace-only access tokens
|
||||||
if not access_token or not str(access_token).strip():
|
if not access_token or not str(access_token).strip():
|
||||||
logging.warning("UserService.query: Rejecting empty access_token query")
|
logging.warning("UserService.query: Rejecting empty access_token query")
|
||||||
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
|
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
|
||||||
|
|
||||||
# Reject tokens that are too short (should be UUID, 32+ chars)
|
# Reject tokens that are too short (should be UUID, 32+ chars)
|
||||||
if len(str(access_token).strip()) < 32:
|
if len(str(access_token).strip()) < 32:
|
||||||
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
|
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
|
||||||
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
|
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
|
||||||
|
|
||||||
# Reject tokens that start with "INVALID_" (from logout)
|
# Reject tokens that start with "INVALID_" (from logout)
|
||||||
if str(access_token).startswith("INVALID_"):
|
if str(access_token).startswith("INVALID_"):
|
||||||
logging.warning("UserService.query: Rejecting invalidated access_token")
|
logging.warning("UserService.query: Rejecting invalidated access_token")
|
||||||
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
|
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
|
||||||
|
|
||||||
# Call parent query method for valid requests
|
# Call parent query method for valid requests
|
||||||
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
||||||
|
|
||||||
@ -100,6 +100,12 @@ class UserService(CommonService):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def query_user_by_email(cls, email):
|
||||||
|
users = cls.model.select().where((cls.model.email == email))
|
||||||
|
return list(users)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def save(cls, **kwargs):
|
def save(cls, **kwargs):
|
||||||
@ -133,6 +139,17 @@ class UserService(CommonService):
|
|||||||
cls.model.update(user_dict).where(
|
cls.model.update(user_dict).where(
|
||||||
cls.model.id == user_id).execute()
|
cls.model.id == user_id).execute()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def update_user_password(cls, user_id, new_password):
|
||||||
|
with DB.atomic():
|
||||||
|
update_dict = {
|
||||||
|
"password": generate_password_hash(str(new_password)),
|
||||||
|
"update_time": current_timestamp(),
|
||||||
|
"update_date": datetime_format(datetime.now())
|
||||||
|
}
|
||||||
|
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def is_admin(cls, user_id):
|
def is_admin(cls, user_id):
|
||||||
@ -140,6 +157,12 @@ class UserService(CommonService):
|
|||||||
cls.model.id == user_id,
|
cls.model.id == user_id,
|
||||||
cls.model.is_superuser == 1).count() > 0
|
cls.model.is_superuser == 1).count() > 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_all_users(cls):
|
||||||
|
users = cls.model.select()
|
||||||
|
return list(users)
|
||||||
|
|
||||||
|
|
||||||
class TenantService(CommonService):
|
class TenantService(CommonService):
|
||||||
"""Service class for managing tenant-related database operations.
|
"""Service class for managing tenant-related database operations.
|
||||||
@ -265,6 +288,17 @@ class UserTenantService(CommonService):
|
|||||||
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_user_tenant_relation_by_user_id(cls, user_id):
|
||||||
|
fields = [
|
||||||
|
cls.model.id,
|
||||||
|
cls.model.user_id,
|
||||||
|
cls.model.tenant_id,
|
||||||
|
cls.model.role
|
||||||
|
]
|
||||||
|
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_num_members(cls, user_id: str):
|
def get_num_members(cls, user_id: str):
|
||||||
|
|||||||
@ -41,7 +41,7 @@ from api import utils
|
|||||||
from api.db.db_models import init_database_tables as init_web_db
|
from api.db.db_models import init_database_tables as init_web_db
|
||||||
from api.db.init_data import init_web_data
|
from api.db.init_data import init_web_data
|
||||||
from api.versions import get_ragflow_version
|
from api.versions import get_ragflow_version
|
||||||
from api.utils import show_configs
|
from api.utils.configs import show_configs
|
||||||
from rag.settings import print_rag_settings
|
from rag.settings import print_rag_settings
|
||||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||||
from rag.utils.redis_conn import RedisDistributedLock
|
from rag.utils.redis_conn import RedisDistributedLock
|
||||||
|
|||||||
@ -24,7 +24,7 @@ import rag.utils.es_conn
|
|||||||
import rag.utils.infinity_conn
|
import rag.utils.infinity_conn
|
||||||
import rag.utils.opensearch_conn
|
import rag.utils.opensearch_conn
|
||||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||||
from api.utils import decrypt_database_config, get_base_config
|
from api.utils.configs import decrypt_database_config, get_base_config
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
|
|
||||||
|
|||||||
@ -16,184 +16,15 @@
|
|||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import requests
|
import requests
|
||||||
import logging
|
|
||||||
import copy
|
|
||||||
from enum import Enum, IntEnum
|
|
||||||
import importlib
|
import importlib
|
||||||
from Cryptodome.PublicKey import RSA
|
|
||||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
|
||||||
from filelock import FileLock
|
|
||||||
from api.constants import SERVICE_CONF
|
|
||||||
|
|
||||||
from . import file_utils
|
from .common import string_to_bytes
|
||||||
|
|
||||||
|
|
||||||
def conf_realpath(conf_name):
|
|
||||||
conf_path = f"conf/{conf_name}"
|
|
||||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
|
||||||
|
|
||||||
|
|
||||||
def read_config(conf_name=SERVICE_CONF):
|
|
||||||
local_config = {}
|
|
||||||
local_path = conf_realpath(f'local.{conf_name}')
|
|
||||||
|
|
||||||
# load local config file
|
|
||||||
if os.path.exists(local_path):
|
|
||||||
local_config = file_utils.load_yaml_conf(local_path)
|
|
||||||
if not isinstance(local_config, dict):
|
|
||||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
|
||||||
|
|
||||||
global_config_path = conf_realpath(conf_name)
|
|
||||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
|
||||||
|
|
||||||
if not isinstance(global_config, dict):
|
|
||||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
|
||||||
|
|
||||||
global_config.update(local_config)
|
|
||||||
return global_config
|
|
||||||
|
|
||||||
|
|
||||||
CONFIGS = read_config()
|
|
||||||
|
|
||||||
|
|
||||||
def show_configs():
|
|
||||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
|
||||||
for k, v in CONFIGS.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
if "password" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["password"] = "*" * 8
|
|
||||||
if "access_key" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["access_key"] = "*" * 8
|
|
||||||
if "secret_key" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["secret_key"] = "*" * 8
|
|
||||||
if "secret" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["secret"] = "*" * 8
|
|
||||||
if "sas_token" in v:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
v["sas_token"] = "*" * 8
|
|
||||||
if "oauth" in k:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
for key, val in v.items():
|
|
||||||
if "client_secret" in val:
|
|
||||||
val["client_secret"] = "*" * 8
|
|
||||||
if "authentication" in k:
|
|
||||||
v = copy.deepcopy(v)
|
|
||||||
for key, val in v.items():
|
|
||||||
if "http_secret_key" in val:
|
|
||||||
val["http_secret_key"] = "*" * 8
|
|
||||||
msg += f"\n\t{k}: {v}"
|
|
||||||
logging.info(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def get_base_config(key, default=None):
|
|
||||||
if key is None:
|
|
||||||
return None
|
|
||||||
if default is None:
|
|
||||||
default = os.environ.get(key.upper())
|
|
||||||
return CONFIGS.get(key, default)
|
|
||||||
|
|
||||||
|
|
||||||
use_deserialize_safe_module = get_base_config(
|
|
||||||
'use_deserialize_safe_module', False)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseType:
|
|
||||||
def to_dict(self):
|
|
||||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
|
||||||
|
|
||||||
def to_dict_with_type(self):
|
|
||||||
def _dict(obj):
|
|
||||||
module = None
|
|
||||||
if issubclass(obj.__class__, BaseType):
|
|
||||||
data = {}
|
|
||||||
for attr, v in obj.__dict__.items():
|
|
||||||
k = attr.lstrip("_")
|
|
||||||
data[k] = _dict(v)
|
|
||||||
module = obj.__module__
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
data = []
|
|
||||||
for i, vv in enumerate(obj):
|
|
||||||
data.append(_dict(vv))
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
data = {}
|
|
||||||
for _k, vv in obj.items():
|
|
||||||
data[_k] = _dict(vv)
|
|
||||||
else:
|
|
||||||
data = obj
|
|
||||||
return {"type": obj.__class__.__name__,
|
|
||||||
"data": data, "module": module}
|
|
||||||
|
|
||||||
return _dict(self)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomJSONEncoder(json.JSONEncoder):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self._with_type = kwargs.pop("with_type", False)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def default(self, obj):
|
|
||||||
if isinstance(obj, datetime.datetime):
|
|
||||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
elif isinstance(obj, datetime.date):
|
|
||||||
return obj.strftime('%Y-%m-%d')
|
|
||||||
elif isinstance(obj, datetime.timedelta):
|
|
||||||
return str(obj)
|
|
||||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
|
||||||
return obj.value
|
|
||||||
elif isinstance(obj, set):
|
|
||||||
return list(obj)
|
|
||||||
elif issubclass(type(obj), BaseType):
|
|
||||||
if not self._with_type:
|
|
||||||
return obj.to_dict()
|
|
||||||
else:
|
|
||||||
return obj.to_dict_with_type()
|
|
||||||
elif isinstance(obj, type):
|
|
||||||
return obj.__name__
|
|
||||||
else:
|
|
||||||
return json.JSONEncoder.default(self, obj)
|
|
||||||
|
|
||||||
|
|
||||||
def rag_uuid():
|
|
||||||
return uuid.uuid1().hex
|
|
||||||
|
|
||||||
|
|
||||||
def string_to_bytes(string):
|
|
||||||
return string if isinstance(
|
|
||||||
string, bytes) else string.encode(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_string(byte):
|
|
||||||
return byte.decode(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
|
||||||
dest = json.dumps(
|
|
||||||
src,
|
|
||||||
indent=indent,
|
|
||||||
cls=CustomJSONEncoder,
|
|
||||||
with_type=with_type)
|
|
||||||
if byte:
|
|
||||||
dest = string_to_bytes(dest)
|
|
||||||
return dest
|
|
||||||
|
|
||||||
|
|
||||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
|
||||||
if isinstance(src, bytes):
|
|
||||||
src = bytes_to_string(src)
|
|
||||||
return json.loads(src, object_hook=object_hook,
|
|
||||||
object_pairs_hook=object_pairs_hook)
|
|
||||||
|
|
||||||
|
|
||||||
def current_timestamp():
|
def current_timestamp():
|
||||||
@ -215,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
|
|||||||
return time_stamp
|
return time_stamp
|
||||||
|
|
||||||
|
|
||||||
def serialize_b64(src, to_str=False):
|
|
||||||
dest = base64.b64encode(pickle.dumps(src))
|
|
||||||
if not to_str:
|
|
||||||
return dest
|
|
||||||
else:
|
|
||||||
return bytes_to_string(dest)
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_b64(src):
|
|
||||||
src = base64.b64decode(
|
|
||||||
string_to_bytes(src) if isinstance(
|
|
||||||
src, str) else src)
|
|
||||||
if use_deserialize_safe_module:
|
|
||||||
return restricted_loads(src)
|
|
||||||
return pickle.loads(src)
|
|
||||||
|
|
||||||
|
|
||||||
safe_module = {
|
|
||||||
'numpy',
|
|
||||||
'rag_flow'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
import importlib
|
|
||||||
if module.split('.')[0] in safe_module:
|
|
||||||
_module = importlib.import_module(module)
|
|
||||||
return getattr(_module, name)
|
|
||||||
# Forbid everything else.
|
|
||||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
|
||||||
(module, name))
|
|
||||||
|
|
||||||
|
|
||||||
def restricted_loads(src):
|
|
||||||
"""Helper function analogous to pickle.loads()."""
|
|
||||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
|
||||||
|
|
||||||
|
|
||||||
def get_lan_ip():
|
def get_lan_ip():
|
||||||
if os.name != "nt":
|
if os.name != "nt":
|
||||||
import fcntl
|
import fcntl
|
||||||
@ -298,47 +90,6 @@ def from_dict_hook(in_dict: dict):
|
|||||||
return in_dict
|
return in_dict
|
||||||
|
|
||||||
|
|
||||||
def decrypt_database_password(password):
|
|
||||||
encrypt_password = get_base_config("encrypt_password", False)
|
|
||||||
encrypt_module = get_base_config("encrypt_module", False)
|
|
||||||
private_key = get_base_config("private_key", None)
|
|
||||||
|
|
||||||
if not password or not encrypt_password:
|
|
||||||
return password
|
|
||||||
|
|
||||||
if not private_key:
|
|
||||||
raise ValueError("No private key")
|
|
||||||
|
|
||||||
module_fun = encrypt_module.split("#")
|
|
||||||
pwdecrypt_fun = getattr(
|
|
||||||
importlib.import_module(
|
|
||||||
module_fun[0]),
|
|
||||||
module_fun[1])
|
|
||||||
|
|
||||||
return pwdecrypt_fun(private_key, password)
|
|
||||||
|
|
||||||
|
|
||||||
def decrypt_database_config(
|
|
||||||
database=None, passwd_key="password", name="database"):
|
|
||||||
if not database:
|
|
||||||
database = get_base_config(name, {})
|
|
||||||
|
|
||||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
|
||||||
return database
|
|
||||||
|
|
||||||
|
|
||||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
|
||||||
conf_path = conf_realpath(conf_name=conf_name)
|
|
||||||
if not os.path.isabs(conf_path):
|
|
||||||
conf_path = os.path.join(
|
|
||||||
file_utils.get_project_base_directory(), conf_path)
|
|
||||||
|
|
||||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
|
||||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
|
||||||
config[key] = value
|
|
||||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_uuid():
|
def get_uuid():
|
||||||
return uuid.uuid1().hex
|
return uuid.uuid1().hex
|
||||||
|
|
||||||
@ -363,37 +114,6 @@ def elapsed2time(elapsed):
|
|||||||
return '%02d:%02d:%02d' % (hour, minuter, second)
|
return '%02d:%02d:%02d' % (hour, minuter, second)
|
||||||
|
|
||||||
|
|
||||||
def decrypt(line):
|
|
||||||
file_path = os.path.join(
|
|
||||||
file_utils.get_project_base_directory(),
|
|
||||||
"conf",
|
|
||||||
"private.pem")
|
|
||||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
|
||||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
|
||||||
return cipher.decrypt(base64.b64decode(
|
|
||||||
line), "Fail to decrypt password!").decode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
def decrypt2(crypt_text):
|
|
||||||
from base64 import b64decode, b16decode
|
|
||||||
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
|
||||||
from Crypto.PublicKey import RSA
|
|
||||||
decode_data = b64decode(crypt_text)
|
|
||||||
if len(decode_data) == 127:
|
|
||||||
hex_fixed = '00' + decode_data.hex()
|
|
||||||
decode_data = b16decode(hex_fixed.upper())
|
|
||||||
|
|
||||||
file_path = os.path.join(
|
|
||||||
file_utils.get_project_base_directory(),
|
|
||||||
"conf",
|
|
||||||
"private.pem")
|
|
||||||
pem = open(file_path).read()
|
|
||||||
rsa_key = RSA.importKey(pem, "Welcome")
|
|
||||||
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
|
||||||
decrypt_text = cipher.decrypt(decode_data, None)
|
|
||||||
return (b64decode(decrypt_text)).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def download_img(url):
|
def download_img(url):
|
||||||
if not url:
|
if not url:
|
||||||
return ""
|
return ""
|
||||||
@ -408,5 +128,5 @@ def delta_seconds(date_string: str):
|
|||||||
return (datetime.datetime.now() - dt).total_seconds()
|
return (datetime.datetime.now() - dt).total_seconds()
|
||||||
|
|
||||||
|
|
||||||
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
|
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
|
||||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||||
|
|||||||
@ -39,6 +39,7 @@ from flask import (
|
|||||||
make_response,
|
make_response,
|
||||||
send_file,
|
send_file,
|
||||||
)
|
)
|
||||||
|
from flask_login import current_user
|
||||||
from flask import (
|
from flask import (
|
||||||
request as flask_request,
|
request as flask_request,
|
||||||
)
|
)
|
||||||
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
|
|||||||
|
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||||
|
from api.db import ActiveEnum
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
|
from api.db.services import UserService
|
||||||
from api.db.services.llm_service import LLMService
|
from api.db.services.llm_service import LLMService
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
from api.utils.json import CustomJSONEncoder, json_dumps
|
||||||
|
from api.utils import get_uuid
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||||
|
|
||||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||||
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def active_required(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
user_id = current_user.id
|
||||||
|
usr = UserService.filter_by_id(user_id)
|
||||||
|
# check is_active
|
||||||
|
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||||
|
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def is_localhost(ip):
|
def is_localhost(ip):
|
||||||
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
||||||
|
|
||||||
@ -643,6 +659,16 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
|
|||||||
return transformed_data
|
return transformed_data
|
||||||
|
|
||||||
|
|
||||||
|
def group_by(list_of_dict, key):
|
||||||
|
res = {}
|
||||||
|
for item in list_of_dict:
|
||||||
|
if item[key] in res.keys():
|
||||||
|
res[item[key]].append(item)
|
||||||
|
else:
|
||||||
|
res[item[key]] = [item]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
|
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
|
||||||
results = {}
|
results = {}
|
||||||
tool_call_sessions = []
|
tool_call_sessions = []
|
||||||
|
|||||||
23
api/utils/common.py
Normal file
23
api/utils/common.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
def string_to_bytes(string):
|
||||||
|
return string if isinstance(
|
||||||
|
string, bytes) else string.encode(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_string(byte):
|
||||||
|
return byte.decode(encoding="utf-8")
|
||||||
179
api/utils/configs.py
Normal file
179
api/utils/configs.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import base64
|
||||||
|
import pickle
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from api.utils import file_utils
|
||||||
|
from filelock import FileLock
|
||||||
|
from api.utils.common import bytes_to_string, string_to_bytes
|
||||||
|
from api.constants import SERVICE_CONF
|
||||||
|
|
||||||
|
|
||||||
|
def conf_realpath(conf_name):
|
||||||
|
conf_path = f"conf/{conf_name}"
|
||||||
|
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||||
|
|
||||||
|
|
||||||
|
def read_config(conf_name=SERVICE_CONF):
|
||||||
|
local_config = {}
|
||||||
|
local_path = conf_realpath(f'local.{conf_name}')
|
||||||
|
|
||||||
|
# load local config file
|
||||||
|
if os.path.exists(local_path):
|
||||||
|
local_config = file_utils.load_yaml_conf(local_path)
|
||||||
|
if not isinstance(local_config, dict):
|
||||||
|
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||||
|
|
||||||
|
global_config_path = conf_realpath(conf_name)
|
||||||
|
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||||
|
|
||||||
|
if not isinstance(global_config, dict):
|
||||||
|
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||||
|
|
||||||
|
global_config.update(local_config)
|
||||||
|
return global_config
|
||||||
|
|
||||||
|
|
||||||
|
CONFIGS = read_config()
|
||||||
|
|
||||||
|
|
||||||
|
def show_configs():
|
||||||
|
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||||
|
for k, v in CONFIGS.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
if "password" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["password"] = "*" * 8
|
||||||
|
if "access_key" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["access_key"] = "*" * 8
|
||||||
|
if "secret_key" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["secret_key"] = "*" * 8
|
||||||
|
if "secret" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["secret"] = "*" * 8
|
||||||
|
if "sas_token" in v:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
v["sas_token"] = "*" * 8
|
||||||
|
if "oauth" in k:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
for key, val in v.items():
|
||||||
|
if "client_secret" in val:
|
||||||
|
val["client_secret"] = "*" * 8
|
||||||
|
if "authentication" in k:
|
||||||
|
v = copy.deepcopy(v)
|
||||||
|
for key, val in v.items():
|
||||||
|
if "http_secret_key" in val:
|
||||||
|
val["http_secret_key"] = "*" * 8
|
||||||
|
msg += f"\n\t{k}: {v}"
|
||||||
|
logging.info(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_config(key, default=None):
|
||||||
|
if key is None:
|
||||||
|
return None
|
||||||
|
if default is None:
|
||||||
|
default = os.environ.get(key.upper())
|
||||||
|
return CONFIGS.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_database_password(password):
|
||||||
|
encrypt_password = get_base_config("encrypt_password", False)
|
||||||
|
encrypt_module = get_base_config("encrypt_module", False)
|
||||||
|
private_key = get_base_config("private_key", None)
|
||||||
|
|
||||||
|
if not password or not encrypt_password:
|
||||||
|
return password
|
||||||
|
|
||||||
|
if not private_key:
|
||||||
|
raise ValueError("No private key")
|
||||||
|
|
||||||
|
module_fun = encrypt_module.split("#")
|
||||||
|
pwdecrypt_fun = getattr(
|
||||||
|
importlib.import_module(
|
||||||
|
module_fun[0]),
|
||||||
|
module_fun[1])
|
||||||
|
|
||||||
|
return pwdecrypt_fun(private_key, password)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_database_config(
|
||||||
|
database=None, passwd_key="password", name="database"):
|
||||||
|
if not database:
|
||||||
|
database = get_base_config(name, {})
|
||||||
|
|
||||||
|
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||||
|
return database
|
||||||
|
|
||||||
|
|
||||||
|
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||||
|
conf_path = conf_realpath(conf_name=conf_name)
|
||||||
|
if not os.path.isabs(conf_path):
|
||||||
|
conf_path = os.path.join(
|
||||||
|
file_utils.get_project_base_directory(), conf_path)
|
||||||
|
|
||||||
|
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||||
|
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||||
|
config[key] = value
|
||||||
|
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||||
|
|
||||||
|
|
||||||
|
safe_module = {
|
||||||
|
'numpy',
|
||||||
|
'rag_flow'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictedUnpickler(pickle.Unpickler):
|
||||||
|
def find_class(self, module, name):
|
||||||
|
import importlib
|
||||||
|
if module.split('.')[0] in safe_module:
|
||||||
|
_module = importlib.import_module(module)
|
||||||
|
return getattr(_module, name)
|
||||||
|
# Forbid everything else.
|
||||||
|
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||||
|
(module, name))
|
||||||
|
|
||||||
|
|
||||||
|
def restricted_loads(src):
|
||||||
|
"""Helper function analogous to pickle.loads()."""
|
||||||
|
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_b64(src, to_str=False):
|
||||||
|
dest = base64.b64encode(pickle.dumps(src))
|
||||||
|
if not to_str:
|
||||||
|
return dest
|
||||||
|
else:
|
||||||
|
return bytes_to_string(dest)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_b64(src):
|
||||||
|
src = base64.b64decode(
|
||||||
|
string_to_bytes(src) if isinstance(
|
||||||
|
src, str) else src)
|
||||||
|
use_deserialize_safe_module = get_base_config(
|
||||||
|
'use_deserialize_safe_module', False)
|
||||||
|
if use_deserialize_safe_module:
|
||||||
|
return restricted_loads(src)
|
||||||
|
return pickle.loads(src)
|
||||||
64
api/utils/crypt.py
Normal file
64
api/utils/crypt.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from Cryptodome.PublicKey import RSA
|
||||||
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||||
|
from api.utils import file_utils
|
||||||
|
|
||||||
|
|
||||||
|
def crypt(line):
|
||||||
|
"""
|
||||||
|
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||||
|
"""
|
||||||
|
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||||
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||||
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||||
|
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||||
|
encrypted_password = cipher.encrypt(password_base64.encode())
|
||||||
|
return base64.b64encode(encrypted_password).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt(line):
|
||||||
|
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||||
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||||
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||||
|
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt2(crypt_text):
|
||||||
|
from base64 import b64decode, b16decode
|
||||||
|
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
||||||
|
from Crypto.PublicKey import RSA
|
||||||
|
decode_data = b64decode(crypt_text)
|
||||||
|
if len(decode_data) == 127:
|
||||||
|
hex_fixed = '00' + decode_data.hex()
|
||||||
|
decode_data = b16decode(hex_fixed.upper())
|
||||||
|
|
||||||
|
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||||
|
pem = open(file_path).read()
|
||||||
|
rsa_key = RSA.importKey(pem, "Welcome")
|
||||||
|
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
||||||
|
decrypt_text = cipher.decrypt(decode_data, None)
|
||||||
|
return (b64decode(decrypt_text)).decode()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
passwd = crypt(sys.argv[1])
|
||||||
|
print(passwd)
|
||||||
|
print(decrypt(passwd))
|
||||||
107
api/utils/health_utils.py
Normal file
107
api/utils/health_utils.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
|
from api import settings
|
||||||
|
from api.db.db_models import DB
|
||||||
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
|
||||||
|
|
||||||
|
def _ok_nok(ok: bool) -> str:
|
||||||
|
return "ok" if ok else "nok"
|
||||||
|
|
||||||
|
|
||||||
|
def check_db() -> tuple[bool, dict]:
|
||||||
|
st = timer()
|
||||||
|
try:
|
||||||
|
# lightweight probe; works for MySQL/Postgres
|
||||||
|
DB.execute_sql("SELECT 1")
|
||||||
|
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||||
|
except Exception as e:
|
||||||
|
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def check_redis() -> tuple[bool, dict]:
|
||||||
|
st = timer()
|
||||||
|
try:
|
||||||
|
ok = bool(REDIS_CONN.health())
|
||||||
|
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||||
|
except Exception as e:
|
||||||
|
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def check_doc_engine() -> tuple[bool, dict]:
|
||||||
|
st = timer()
|
||||||
|
try:
|
||||||
|
meta = settings.docStoreConn.health()
|
||||||
|
# treat any successful call as ok
|
||||||
|
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
|
||||||
|
except Exception as e:
|
||||||
|
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def check_storage() -> tuple[bool, dict]:
|
||||||
|
st = timer()
|
||||||
|
try:
|
||||||
|
STORAGE_IMPL.health()
|
||||||
|
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||||
|
except Exception as e:
|
||||||
|
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_health_checks() -> tuple[dict, bool]:
|
||||||
|
result: dict[str, str | dict] = {}
|
||||||
|
|
||||||
|
db_ok, db_meta = check_db()
|
||||||
|
result["db"] = _ok_nok(db_ok)
|
||||||
|
if not db_ok:
|
||||||
|
result.setdefault("_meta", {})["db"] = db_meta
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis_ok, redis_meta = check_redis()
|
||||||
|
result["redis"] = _ok_nok(redis_ok)
|
||||||
|
if not redis_ok:
|
||||||
|
result.setdefault("_meta", {})["redis"] = redis_meta
|
||||||
|
except Exception:
|
||||||
|
result["redis"] = "nok"
|
||||||
|
|
||||||
|
try:
|
||||||
|
doc_ok, doc_meta = check_doc_engine()
|
||||||
|
result["doc_engine"] = _ok_nok(doc_ok)
|
||||||
|
if not doc_ok:
|
||||||
|
result.setdefault("_meta", {})["doc_engine"] = doc_meta
|
||||||
|
except Exception:
|
||||||
|
result["doc_engine"] = "nok"
|
||||||
|
|
||||||
|
try:
|
||||||
|
sto_ok, sto_meta = check_storage()
|
||||||
|
result["storage"] = _ok_nok(sto_ok)
|
||||||
|
if not sto_ok:
|
||||||
|
result.setdefault("_meta", {})["storage"] = sto_meta
|
||||||
|
except Exception:
|
||||||
|
result["storage"] = "nok"
|
||||||
|
|
||||||
|
|
||||||
|
all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok")
|
||||||
|
result["status"] = "ok" if all_ok else "nok"
|
||||||
|
return result, all_ok
|
||||||
|
|
||||||
|
|
||||||
78
api/utils/json.py
Normal file
78
api/utils/json.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
from enum import Enum, IntEnum
|
||||||
|
from api.utils.common import string_to_bytes, bytes_to_string
|
||||||
|
|
||||||
|
|
||||||
|
class BaseType:
|
||||||
|
def to_dict(self):
|
||||||
|
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||||
|
|
||||||
|
def to_dict_with_type(self):
|
||||||
|
def _dict(obj):
|
||||||
|
module = None
|
||||||
|
if issubclass(obj.__class__, BaseType):
|
||||||
|
data = {}
|
||||||
|
for attr, v in obj.__dict__.items():
|
||||||
|
k = attr.lstrip("_")
|
||||||
|
data[k] = _dict(v)
|
||||||
|
module = obj.__module__
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
data = []
|
||||||
|
for i, vv in enumerate(obj):
|
||||||
|
data.append(_dict(vv))
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
data = {}
|
||||||
|
for _k, vv in obj.items():
|
||||||
|
data[_k] = _dict(vv)
|
||||||
|
else:
|
||||||
|
data = obj
|
||||||
|
return {"type": obj.__class__.__name__,
|
||||||
|
"data": data, "module": module}
|
||||||
|
|
||||||
|
return _dict(self)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomJSONEncoder(json.JSONEncoder):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._with_type = kwargs.pop("with_type", False)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, datetime.datetime):
|
||||||
|
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
elif isinstance(obj, datetime.date):
|
||||||
|
return obj.strftime('%Y-%m-%d')
|
||||||
|
elif isinstance(obj, datetime.timedelta):
|
||||||
|
return str(obj)
|
||||||
|
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||||
|
return obj.value
|
||||||
|
elif isinstance(obj, set):
|
||||||
|
return list(obj)
|
||||||
|
elif issubclass(type(obj), BaseType):
|
||||||
|
if not self._with_type:
|
||||||
|
return obj.to_dict()
|
||||||
|
else:
|
||||||
|
return obj.to_dict_with_type()
|
||||||
|
elif isinstance(obj, type):
|
||||||
|
return obj.__name__
|
||||||
|
else:
|
||||||
|
return json.JSONEncoder.default(self, obj)
|
||||||
|
|
||||||
|
|
||||||
|
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||||
|
dest = json.dumps(
|
||||||
|
src,
|
||||||
|
indent=indent,
|
||||||
|
cls=CustomJSONEncoder,
|
||||||
|
with_type=with_type)
|
||||||
|
if byte:
|
||||||
|
dest = string_to_bytes(dest)
|
||||||
|
return dest
|
||||||
|
|
||||||
|
|
||||||
|
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||||
|
if isinstance(src, bytes):
|
||||||
|
src = bytes_to_string(src)
|
||||||
|
return json.loads(src, object_hook=object_hook,
|
||||||
|
object_pairs_hook=object_pairs_hook)
|
||||||
@ -1,40 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from Cryptodome.PublicKey import RSA
|
|
||||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
|
||||||
from api.utils import decrypt, file_utils
|
|
||||||
|
|
||||||
|
|
||||||
def crypt(line):
|
|
||||||
file_path = os.path.join(
|
|
||||||
file_utils.get_project_base_directory(),
|
|
||||||
"conf",
|
|
||||||
"public.pem")
|
|
||||||
rsa_key = RSA.importKey(open(file_path).read(),"Welcome")
|
|
||||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
|
||||||
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
|
||||||
encrypted_password = cipher.encrypt(password_base64.encode())
|
|
||||||
return base64.b64encode(encrypted_password).decode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
passwd = crypt(sys.argv[1])
|
|
||||||
print(passwd)
|
|
||||||
print(decrypt(passwd))
|
|
||||||
19
chat_demo/index.html
Normal file
19
chat_demo/index.html
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
<iframe src="http://localhost:9222/next-chats/widget?shared_id=9dcfc68696c611f0bb789b9b8b765d12&from=chat&auth=U4MDU3NzkwOTZjNzExZjBiYjc4OWI5Yj&mode=master&streaming=false"
|
||||||
|
style="position:fixed;bottom:0;right:0;width:100px;height:100px;border:none;background:transparent;z-index:9999"
|
||||||
|
frameborder="0" allow="microphone;camera"></iframe>
|
||||||
|
<script>
|
||||||
|
window.addEventListener('message',e=>{
|
||||||
|
if(e.origin!=='http://localhost:9222')return;
|
||||||
|
if(e.data.type==='CREATE_CHAT_WINDOW'){
|
||||||
|
if(document.getElementById('chat-win'))return;
|
||||||
|
const i=document.createElement('iframe');
|
||||||
|
i.id='chat-win';i.src=e.data.src;
|
||||||
|
i.style.cssText='position:fixed;bottom:104px;right:24px;width:380px;height:500px;border:none;background:transparent;z-index:9998;display:none';
|
||||||
|
i.frameBorder='0';i.allow='microphone;camera';
|
||||||
|
document.body.appendChild(i);
|
||||||
|
}else if(e.data.type==='TOGGLE_CHAT'){
|
||||||
|
const w=document.getElementById('chat-win');
|
||||||
|
if(w)w.style.display=e.data.isOpen?'block':'none';
|
||||||
|
}else if(e.data.type==='SCROLL_PASSTHROUGH')window.scrollBy(0,e.data.deltaY);
|
||||||
|
});
|
||||||
|
</script>
|
||||||
154
chat_demo/widget_demo.html
Normal file
154
chat_demo/widget_demo.html
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Floating Chat Widget Demo</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
margin: 0;
|
||||||
|
padding: 40px;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
min-height: 100vh;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-content {
|
||||||
|
max-width: 800px;
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-content h1 {
|
||||||
|
text-align: center;
|
||||||
|
font-size: 2.5rem;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.demo-content p {
|
||||||
|
font-size: 1.2rem;
|
||||||
|
line-height: 1.6;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.feature-list {
|
||||||
|
background: rgba(255, 255, 255, 0.1);
|
||||||
|
border-radius: 10px;
|
||||||
|
padding: 2rem;
|
||||||
|
margin: 2rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.feature-list h3 {
|
||||||
|
margin-top: 0;
|
||||||
|
font-size: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.feature-list ul {
|
||||||
|
list-style-type: none;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.feature-list li {
|
||||||
|
padding: 0.5rem 0;
|
||||||
|
padding-left: 1.5rem;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.feature-list li:before {
|
||||||
|
content: "✓";
|
||||||
|
position: absolute;
|
||||||
|
left: 0;
|
||||||
|
color: #4ade80;
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="demo-content">
|
||||||
|
<h1>🚀 Floating Chat Widget Demo</h1>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
Welcome to our demo page! This page simulates a real website with content.
|
||||||
|
Look for the floating chat button in the bottom-right corner - just like Intercom!
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div class="feature-list">
|
||||||
|
<h3>🎯 Widget Features</h3>
|
||||||
|
<ul>
|
||||||
|
<li>Floating button that stays visible while scrolling</li>
|
||||||
|
<li>Click to open/close the chat window</li>
|
||||||
|
<li>Minimize button to collapse the chat</li>
|
||||||
|
<li>Professional Intercom-style design</li>
|
||||||
|
<li>Unread message indicator (red badge)</li>
|
||||||
|
<li>Transparent background integration</li>
|
||||||
|
<li>Responsive design for all screen sizes</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
The chat widget is completely separate from your website's content and won't
|
||||||
|
interfere with your existing layout or functionality. It's designed to be
|
||||||
|
lightweight and performant.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
Try scrolling this page - notice how the chat button stays in position.
|
||||||
|
Click it to start a conversation with our AI assistant!
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div class="feature-list">
|
||||||
|
<h3>🔧 Implementation</h3>
|
||||||
|
<ul>
|
||||||
|
<li>Simple iframe embed - just copy and paste</li>
|
||||||
|
<li>No JavaScript dependencies required</li>
|
||||||
|
<li>Works on any website or platform</li>
|
||||||
|
<li>Customizable appearance and behavior</li>
|
||||||
|
<li>Secure and privacy-focused</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
This is just placeholder content to demonstrate how the widget integrates
|
||||||
|
seamlessly with your existing website content. The widget floats above
|
||||||
|
everything else without disrupting your user experience.
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<p style="margin-top: 4rem; text-align: center; font-style: italic;">
|
||||||
|
🎉 Ready to add this to your website? Get your embed code from the admin panel!
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<iframe id="main-widget" src="http://localhost:9222/next-chats/widget?shared_id=9dcfc68696c611f0bb789b9b8b765d12&from=chat&auth=U4MDU3NzkwOTZjNzExZjBiYjc4OWI5Yj&visible_avatar=1&locale=zh&mode=master&streaming=false"
|
||||||
|
style="position:fixed;bottom:0;right:0;width:100px;height:100px;border:none;background:transparent;z-index:9999;opacity:0;transition:opacity 0.2s ease"
|
||||||
|
frameborder="0" allow="microphone;camera"></iframe>
|
||||||
|
<script>
|
||||||
|
window.addEventListener('message',e=>{
|
||||||
|
if(e.origin!=='http://localhost:9222')return;
|
||||||
|
if(e.data.type==='WIDGET_READY'){
|
||||||
|
// Show the main widget when React is ready
|
||||||
|
const mainWidget = document.getElementById('main-widget');
|
||||||
|
if(mainWidget) mainWidget.style.opacity = '1';
|
||||||
|
}else if(e.data.type==='CREATE_CHAT_WINDOW'){
|
||||||
|
if(document.getElementById('chat-win'))return;
|
||||||
|
const i=document.createElement('iframe');
|
||||||
|
i.id='chat-win';i.src=e.data.src;
|
||||||
|
i.style.cssText='position:fixed;bottom:104px;right:24px;width:380px;height:500px;border:none;background:transparent;z-index:9998;display:none;opacity:0;transition:opacity 0.2s ease';
|
||||||
|
i.frameBorder='0';i.allow='microphone;camera';
|
||||||
|
document.body.appendChild(i);
|
||||||
|
}else if(e.data.type==='TOGGLE_CHAT'){
|
||||||
|
const w=document.getElementById('chat-win');
|
||||||
|
if(w){
|
||||||
|
if(e.data.isOpen){
|
||||||
|
w.style.display='block';
|
||||||
|
// Wait for the iframe content to be ready before showing
|
||||||
|
setTimeout(() => w.style.opacity='1', 100);
|
||||||
|
}else{
|
||||||
|
w.style.opacity='0';
|
||||||
|
setTimeout(() => w.style.display='none', 200);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}else if(e.data.type==='SCROLL_PASSTHROUGH')window.scrollBy(0,e.data.deltaY);
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@ -402,7 +402,7 @@
|
|||||||
"is_tools": true
|
"is_tools": true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"llm_name": "qwen3-max-preview",
|
"llm_name": "qwen3-max",
|
||||||
"tags": "LLM,CHAT,256k",
|
"tags": "LLM,CHAT,256k",
|
||||||
"max_tokens": 256000,
|
"max_tokens": 256000,
|
||||||
"model_type": "chat",
|
"model_type": "chat",
|
||||||
@ -436,6 +436,27 @@
|
|||||||
"model_type": "chat",
|
"model_type": "chat",
|
||||||
"is_tools": true
|
"is_tools": true
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen3-vl-plus",
|
||||||
|
"tags": "LLM,CHAT,IMAGE2TEXT,256k",
|
||||||
|
"max_tokens": 256000,
|
||||||
|
"model_type": "image2text",
|
||||||
|
"is_tools": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen3-vl-235b-a22b-instruct",
|
||||||
|
"tags": "LLM,CHAT,IMAGE2TEXT,128k",
|
||||||
|
"max_tokens": 128000,
|
||||||
|
"model_type": "image2text",
|
||||||
|
"is_tools": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen3-vl-235b-a22b-thinking",
|
||||||
|
"tags": "LLM,CHAT,IMAGE2TEXT,128k",
|
||||||
|
"max_tokens": 128000,
|
||||||
|
"model_type": "image2text",
|
||||||
|
"is_tools": true
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"llm_name": "qwen3-235b-a22b-instruct-2507",
|
"llm_name": "qwen3-235b-a22b-instruct-2507",
|
||||||
"tags": "LLM,CHAT,128k",
|
"tags": "LLM,CHAT,128k",
|
||||||
@ -457,6 +478,20 @@
|
|||||||
"model_type": "chat",
|
"model_type": "chat",
|
||||||
"is_tools": true
|
"is_tools": true
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen3-next-80b-a3b-instruct",
|
||||||
|
"tags": "LLM,CHAT,128k",
|
||||||
|
"max_tokens": 128000,
|
||||||
|
"model_type": "chat",
|
||||||
|
"is_tools": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen3-next-80b-a3b-thinking",
|
||||||
|
"tags": "LLM,CHAT,128k",
|
||||||
|
"max_tokens": 128000,
|
||||||
|
"model_type": "chat",
|
||||||
|
"is_tools": true
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"llm_name": "qwen3-0.6b",
|
"llm_name": "qwen3-0.6b",
|
||||||
"tags": "LLM,CHAT,32k",
|
"tags": "LLM,CHAT,32k",
|
||||||
@ -622,6 +657,13 @@
|
|||||||
"tags": "SPEECH2TEXT,8k",
|
"tags": "SPEECH2TEXT,8k",
|
||||||
"max_tokens": 8000,
|
"max_tokens": 8000,
|
||||||
"model_type": "speech2text"
|
"model_type": "speech2text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qianwen-deepresearch-30b-a3b-131k",
|
||||||
|
"tags": "LLM,CHAT,1M,AGENT,DEEPRESEARCH",
|
||||||
|
"max_tokens": 1000000,
|
||||||
|
"model_type": "chat",
|
||||||
|
"is_tools": true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
ragflow:
|
ragflow:
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
http_port: 9380
|
http_port: 9380
|
||||||
|
admin:
|
||||||
|
host: 0.0.0.0
|
||||||
|
http_port: 9381
|
||||||
mysql:
|
mysql:
|
||||||
name: 'rag_flow'
|
name: 'rag_flow'
|
||||||
user: 'root'
|
user: 'root'
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||||
from rag.prompts import vision_llm_figure_describe_prompt
|
from rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||||
|
|
||||||
|
|
||||||
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||||
|
|||||||
@ -37,7 +37,7 @@ from api.utils.file_utils import get_project_base_directory
|
|||||||
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
|
||||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
from rag.prompts import vision_llm_describe_prompt
|
from rag.prompts.generator import vision_llm_describe_prompt
|
||||||
from rag.settings import PARALLEL_DEVICES
|
from rag.settings import PARALLEL_DEVICES
|
||||||
|
|
||||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||||
|
|||||||
@ -350,7 +350,7 @@ class TextRecognizer:
|
|||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
# close session and release manually
|
# close session and release manually
|
||||||
logging.info('Close TextRecognizer.')
|
logging.info('Close text recognizer.')
|
||||||
if hasattr(self, "predictor"):
|
if hasattr(self, "predictor"):
|
||||||
del self.predictor
|
del self.predictor
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@ -490,7 +490,7 @@ class TextDetector:
|
|||||||
return dt_boxes
|
return dt_boxes
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
logging.info("Close TextDetector.")
|
logging.info("Close text detector.")
|
||||||
if hasattr(self, "predictor"):
|
if hasattr(self, "predictor"):
|
||||||
del self.predictor
|
del self.predictor
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
ragflow:
|
ragflow:
|
||||||
host: ${RAGFLOW_HOST:-0.0.0.0}
|
host: ${RAGFLOW_HOST:-0.0.0.0}
|
||||||
http_port: 9380
|
http_port: 9380
|
||||||
|
admin:
|
||||||
|
host: ${RAGFLOW_HOST:-0.0.0.0}
|
||||||
|
http_port: 9381
|
||||||
mysql:
|
mysql:
|
||||||
name: '${MYSQL_DBNAME:-rag_flow}'
|
name: '${MYSQL_DBNAME:-rag_flow}'
|
||||||
user: '${MYSQL_USER:-root}'
|
user: '${MYSQL_USER:-root}'
|
||||||
|
|||||||
@ -3,6 +3,6 @@
|
|||||||
"position": 40,
|
"position": 40,
|
||||||
"link": {
|
"link": {
|
||||||
"type": "generated-index",
|
"type": "generated-index",
|
||||||
"description": "Guides and references on accessing RAGFlow's knowledge bases via MCP."
|
"description": "Guides and references on accessing RAGFlow's datasets via MCP."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,9 +14,9 @@ A RAGFlow Model Context Protocol (MCP) server is designed as an independent comp
|
|||||||
An MCP server can start up in either self-host mode (default) or host mode:
|
An MCP server can start up in either self-host mode (default) or host mode:
|
||||||
|
|
||||||
- **Self-host mode**:
|
- **Self-host mode**:
|
||||||
When launching an MCP server in self-host mode, you must provide an API key to authenticate the MCP server with the RAGFlow server. In this mode, the MCP server can access *only* the datasets (knowledge bases) of a specified tenant on the RAGFlow server.
|
When launching an MCP server in self-host mode, you must provide an API key to authenticate the MCP server with the RAGFlow server. In this mode, the MCP server can access *only* the datasets of a specified tenant on the RAGFlow server.
|
||||||
- **Host mode**:
|
- **Host mode**:
|
||||||
In host mode, each MCP client can access their own knowledge bases on the RAGFlow server. However, each client request must include a valid API key to authenticate the client with the RAGFlow server.
|
In host mode, each MCP client can access their own datasets on the RAGFlow server. However, each client request must include a valid API key to authenticate the client with the RAGFlow server.
|
||||||
|
|
||||||
Once a connection is established, an MCP server communicates with its client in MCP HTTP+SSE (Server-Sent Events) mode, unidirectionally pushing responses from the RAGFlow server to its client in real time.
|
Once a connection is established, an MCP server communicates with its client in MCP HTTP+SSE (Server-Sent Events) mode, unidirectionally pushing responses from the RAGFlow server to its client in real time.
|
||||||
|
|
||||||
|
|||||||
15
docs/faq.mdx
15
docs/faq.mdx
@ -498,7 +498,7 @@ To switch your document engine from Elasticsearch to [Infinity](https://github.c
|
|||||||
|
|
||||||
### Where are my uploaded files stored in RAGFlow's image?
|
### Where are my uploaded files stored in RAGFlow's image?
|
||||||
|
|
||||||
All uploaded files are stored in Minio, RAGFlow's object storage solution. For instance, if you upload your file directly to a knowledge base, it is located at `<knowledgebase_id>/filename`.
|
All uploaded files are stored in Minio, RAGFlow's object storage solution. For instance, if you upload your file directly to a dataset, it is located at `<knowledgebase_id>/filename`.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -507,3 +507,16 @@ All uploaded files are stored in Minio, RAGFlow's object storage solution. For i
|
|||||||
You can control the batch size for document parsing and embedding by setting the environment variables `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`. Increasing these values may improve throughput for large-scale data processing, but will also increase memory usage. Adjust them according to your hardware resources.
|
You can control the batch size for document parsing and embedding by setting the environment variables `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`. Increasing these values may improve throughput for large-scale data processing, but will also increase memory usage. Adjust them according to your hardware resources.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
### How to accelerate the question-answering speed of my chat assistant?
|
||||||
|
|
||||||
|
See [here](./guides/chat/best_practices/accelerate_question_answering.mdx).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### How to accelerate the question-answering speed of my Agent?
|
||||||
|
|
||||||
|
See [here](./guides/agent/best_practices/accelerate_agent_question_answering.md).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|||||||
@ -229,18 +229,4 @@ The global variable name for the output of the **Agent** component, which can be
|
|||||||
|
|
||||||
### Why does it take so long for my Agent to respond?
|
### Why does it take so long for my Agent to respond?
|
||||||
|
|
||||||
An Agent’s response time generally depends on two key factors: the LLM’s capabilities and the prompt, the latter reflecting task complexity. When using an Agent, you should always balance task demands with the LLM’s ability. See [How to balance task complexity with an Agent's performance and speed?](#how-to-balance-task-complexity-with-an-agents-performance-and-speed) for details.
|
See [here](../best_practices/accelerate_agent_question_answering.md) for details.
|
||||||
|
|
||||||
## Best practices
|
|
||||||
|
|
||||||
### How to balance task complexity with an Agent’s performance and speed?
|
|
||||||
|
|
||||||
- For simple tasks, such as retrieval, rewriting, formatting, or structured data extraction, use concise prompts, remove planning or reasoning instructions, enforce output length limits, and select smaller or Turbo-class models. This significantly reduces latency and cost with minimal impact on quality.
|
|
||||||
|
|
||||||
- For complex tasks, like multi-step reasoning, cross-document synthesis, or tool-based workflows, maintain or enhance prompts that include planning, reflection, and verification steps.
|
|
||||||
|
|
||||||
- In multi-Agent orchestration systems, delegate simple subtasks to sub-Agents using smaller, faster models, and reserve more powerful models for the lead Agent to handle complexity and uncertainty.
|
|
||||||
|
|
||||||
:::tip KEY INSIGHT
|
|
||||||
Focus on minimizing output tokens — through summarization, bullet points, or explicit length limits — as this has far greater impact on reducing latency than optimizing input size.
|
|
||||||
:::
|
|
||||||
@ -67,14 +67,14 @@ You can tune document parsing and embedding efficiency by setting the environmen
|
|||||||
|
|
||||||
## Frequently asked questions
|
## Frequently asked questions
|
||||||
|
|
||||||
### Is the uploaded file in a knowledge base?
|
### Is the uploaded file in a dataset?
|
||||||
|
|
||||||
No. Files uploaded to an agent as input are not stored in a knowledge base and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods.
|
No. Files uploaded to an agent as input are not stored in a dataset and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods.
|
||||||
|
|
||||||
### File size limit for an uploaded file
|
### File size limit for an uploaded file
|
||||||
|
|
||||||
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
|
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
|
||||||
|
|
||||||
:::tip NOTE
|
:::tip NOTE
|
||||||
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a knowledge base or **File Management**. These settings DO NOT apply in this scenario.
|
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario.
|
||||||
:::
|
:::
|
||||||
|
|||||||
@ -49,6 +49,10 @@ You can specify multiple input sources for the **Code** component. Click **+ Add
|
|||||||
|
|
||||||
This field allows you to enter and edit your source code.
|
This field allows you to enter and edit your source code.
|
||||||
|
|
||||||
|
:::danger IMPORTANT
|
||||||
|
If your code implementation includes defined variables, whether input or output variables, ensure they are also specified in the corresponding **Input** or **Output** sections.
|
||||||
|
:::
|
||||||
|
|
||||||
#### A Python code example
|
#### A Python code example
|
||||||
|
|
||||||
```Python
|
```Python
|
||||||
@ -77,6 +81,15 @@ This field allows you to enter and edit your source code.
|
|||||||
|
|
||||||
You define the output variable(s) of the **Code** component here.
|
You define the output variable(s) of the **Code** component here.
|
||||||
|
|
||||||
|
:::danger IMPORTANT
|
||||||
|
If you define output variables here, ensure they are also defined in your code implementation; otherwise, their values will be `null`. The following are two examples:
|
||||||
|
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
:::
|
||||||
|
|
||||||
### Output
|
### Output
|
||||||
|
|
||||||
The defined output variable(s) will be auto-populated here.
|
The defined output variable(s) will be auto-populated here.
|
||||||
|
|||||||
@ -9,7 +9,7 @@ A component that retrieves information from specified datasets.
|
|||||||
|
|
||||||
## Scenarios
|
## Scenarios
|
||||||
|
|
||||||
A **Retrieval** component is essential in most RAG scenarios, where information is extracted from designated knowledge bases before being sent to the LLM for content generation. A **Retrieval** component can operate either as a standalone workflow module or as a tool for an **Agent** component. In the latter role, the **Agent** component has autonomous control over when to invoke it for query and retrieval.
|
A **Retrieval** component is essential in most RAG scenarios, where information is extracted from designated datasets before being sent to the LLM for content generation. A **Retrieval** component can operate either as a standalone workflow module or as a tool for an **Agent** component. In the latter role, the **Agent** component has autonomous control over when to invoke it for query and retrieval.
|
||||||
|
|
||||||
The following screenshot shows a reference design using the **Retrieval** component, where the component serves as a tool for an **Agent** component. You can find it from the **Report Agent Using Knowledge Base** Agent template.
|
The following screenshot shows a reference design using the **Retrieval** component, where the component serves as a tool for an **Agent** component. You can find it from the **Report Agent Using Knowledge Base** Agent template.
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ The following screenshot shows a reference design using the **Retrieval** compon
|
|||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Ensure you [have properly configured your target knowledge base(s)](../../dataset/configure_knowledge_base.md).
|
Ensure you [have properly configured your target dataset(s)](../../dataset/configure_knowledge_base.md).
|
||||||
|
|
||||||
## Quickstart
|
## Quickstart
|
||||||
|
|
||||||
@ -36,9 +36,9 @@ The **Retrieval** component depends on query variables to specify its queries.
|
|||||||
|
|
||||||
By default, you can use `sys.query`, which is the user query and the default output of the **Begin** component. All global variables defined before the **Retrieval** component can also be used as query statements. Use the `(x)` button or type `/` to show all the available query variables.
|
By default, you can use `sys.query`, which is the user query and the default output of the **Begin** component. All global variables defined before the **Retrieval** component can also be used as query statements. Use the `(x)` button or type `/` to show all the available query variables.
|
||||||
|
|
||||||
### 3. Select knowledge base(s) to query
|
### 3. Select dataset(s) to query
|
||||||
|
|
||||||
You can specify one or multiple knowledge bases to retrieve data from. If selecting mutiple, ensure they use the same embedding model.
|
You can specify one or multiple datasets to retrieve data from. If selecting mutiple, ensure they use the same embedding model.
|
||||||
|
|
||||||
### 4. Expand **Advanced Settings** to configure the retrieval method
|
### 4. Expand **Advanced Settings** to configure the retrieval method
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ Using a rerank model will *significantly* increase the system's response time. I
|
|||||||
|
|
||||||
### 5. Enable cross-language search
|
### 5. Enable cross-language search
|
||||||
|
|
||||||
If your user query is different from the languages of the knowledge bases, you can select the target languages in the **Cross-language search** dropdown menu. The model will then translates queries to ensure accurate matching of semantic meaning across languages.
|
If your user query is different from the languages of the datasets, you can select the target languages in the **Cross-language search** dropdown menu. The model will then translates queries to ensure accurate matching of semantic meaning across languages.
|
||||||
|
|
||||||
|
|
||||||
### 6. Test retrieval results
|
### 6. Test retrieval results
|
||||||
@ -76,10 +76,10 @@ The **Retrieval** component relies on query variables to specify its queries. Al
|
|||||||
|
|
||||||
### Knowledge bases
|
### Knowledge bases
|
||||||
|
|
||||||
Select the knowledge base(s) to retrieve data from.
|
Select the dataset(s) to retrieve data from.
|
||||||
|
|
||||||
- If no knowledge base is selected, meaning conversations with the agent will not be based on any knowledge base, ensure that the **Empty response** field is left blank to avoid an error.
|
- If no dataset is selected, meaning conversations with the agent will not be based on any dataset, ensure that the **Empty response** field is left blank to avoid an error.
|
||||||
- If you select multiple knowledge bases, you must ensure that the knowledge bases (datasets) you select use the same embedding model; otherwise, an error message would occur.
|
- If you select multiple datasets, you must ensure that the datasets you select use the same embedding model; otherwise, an error message would occur.
|
||||||
|
|
||||||
### Similarity threshold
|
### Similarity threshold
|
||||||
|
|
||||||
@ -110,11 +110,11 @@ Using a rerank model will *significantly* increase the system's response time.
|
|||||||
|
|
||||||
### Empty response
|
### Empty response
|
||||||
|
|
||||||
- Set this as a response if no results are retrieved from the knowledge base(s) for your query, or
|
- Set this as a response if no results are retrieved from the dataset(s) for your query, or
|
||||||
- Leave this field blank to allow the chat model to improvise when nothing is found.
|
- Leave this field blank to allow the chat model to improvise when nothing is found.
|
||||||
|
|
||||||
:::caution WARNING
|
:::caution WARNING
|
||||||
If you do not specify a knowledge base, you must leave this field blank; otherwise, an error would occur.
|
If you do not specify a dataset, you must leave this field blank; otherwise, an error would occur.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
### Cross-language search
|
### Cross-language search
|
||||||
@ -124,10 +124,10 @@ Select one or more languages for cross‑language search. If no language is sele
|
|||||||
### Use knowledge graph
|
### Use knowledge graph
|
||||||
|
|
||||||
:::caution IMPORTANT
|
:::caution IMPORTANT
|
||||||
Before enabling this feature, ensure you have properly [constructed a knowledge graph from each target knowledge base](../../dataset/construct_knowledge_graph.md).
|
Before enabling this feature, ensure you have properly [constructed a knowledge graph from each target dataset](../../dataset/construct_knowledge_graph.md).
|
||||||
:::
|
:::
|
||||||
|
|
||||||
Whether to use knowledge graph(s) in the specified knowledge base(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.
|
Whether to use knowledge graph(s) in the specified dataset(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.
|
||||||
|
|
||||||
### Output
|
### Output
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ Agents and RAG are complementary techniques, each enhancing the other’s capabi
|
|||||||
Before proceeding, ensure that:
|
Before proceeding, ensure that:
|
||||||
|
|
||||||
1. You have properly set the LLM to use. See the guides on [Configure your API key](../models/llm_api_key_setup.md) or [Deploy a local LLM](../models/deploy_local_llm.mdx) for more information.
|
1. You have properly set the LLM to use. See the guides on [Configure your API key](../models/llm_api_key_setup.md) or [Deploy a local LLM](../models/deploy_local_llm.mdx) for more information.
|
||||||
2. You have a knowledge base configured and the corresponding files properly parsed. See the guide on [Configure a knowledge base](../dataset/configure_knowledge_base.md) for more information.
|
2. You have a dataset configured and the corresponding files properly parsed. See the guide on [Configure a dataset](../dataset/configure_knowledge_base.md) for more information.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
|||||||
8
docs/guides/agent/best_practices/_category_.json
Normal file
8
docs/guides/agent/best_practices/_category_.json
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"label": "Best practices",
|
||||||
|
"position": 30,
|
||||||
|
"link": {
|
||||||
|
"type": "generated-index",
|
||||||
|
"description": "Best practices on Agent configuration."
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,58 @@
|
|||||||
|
---
|
||||||
|
sidebar_position: 1
|
||||||
|
slug: /accelerate_agent_question_answering
|
||||||
|
---
|
||||||
|
|
||||||
|
# Accelerate answering
|
||||||
|
|
||||||
|
A checklist to speed up question answering.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Please note that some of your settings may consume a significant amount of time. If you often find that your question answering is time-consuming, here is a checklist to consider:
|
||||||
|
|
||||||
|
## Balance task complexity with an Agent’s performance and speed?
|
||||||
|
|
||||||
|
An Agent’s response time generally depends on many factors, e.g., the LLM’s capabilities and the prompt, the latter reflecting task complexity. When using an Agent, you should always balance task demands with the LLM’s ability.
|
||||||
|
|
||||||
|
- For simple tasks, such as retrieval, rewriting, formatting, or structured data extraction, use concise prompts, remove planning or reasoning instructions, enforce output length limits, and select smaller or Turbo-class models. This significantly reduces latency and cost with minimal impact on quality.
|
||||||
|
|
||||||
|
- For complex tasks, like multi-step reasoning, cross-document synthesis, or tool-based workflows, maintain or enhance prompts that include planning, reflection, and verification steps.
|
||||||
|
|
||||||
|
- In multi-Agent orchestration systems, delegate simple subtasks to sub-Agents using smaller, faster models, and reserve more powerful models for the lead Agent to handle complexity and uncertainty.
|
||||||
|
|
||||||
|
:::tip KEY INSIGHT
|
||||||
|
Focus on minimizing output tokens — through summarization, bullet points, or explicit length limits — as this has far greater impact on reducing latency than optimizing input size.
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Disable Reasoning
|
||||||
|
|
||||||
|
Disabling the **Reasoning** toggle will reduce the LLM's thinking time. For a model like Qwen3, you also need to add `/no_think` to the system prompt to disable reasoning.
|
||||||
|
|
||||||
|
## Disable Rerank model
|
||||||
|
|
||||||
|
- Leaving the **Rerank model** field empty (in the corresponding **Retrieval** component) will significantly decrease retrieval time.
|
||||||
|
- When using a rerank model, ensure you have a GPU for acceleration; otherwise, the reranking process will be *prohibitively* slow.
|
||||||
|
|
||||||
|
:::tip NOTE
|
||||||
|
Please note that rerank models are essential in certain scenarios. There is always a trade-off between speed and performance; you must weigh the pros against cons for your specific case.
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Check the time taken for each task
|
||||||
|
|
||||||
|
Click the light bulb icon above the *current* dialogue and scroll down the popup window to view the time taken for each task:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
| Item name | Description |
|
||||||
|
| ----------------- | --------------------------------------------------------------------------------------------- |
|
||||||
|
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
|
||||||
|
| Check LLM | Time to validate the specified LLM. |
|
||||||
|
| Create retriever | Time to create a chunk retriever. |
|
||||||
|
| Bind embedding | Time to initialize an embedding model instance. |
|
||||||
|
| Bind LLM | Time to initialize an LLM instance. |
|
||||||
|
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
|
||||||
|
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
||||||
|
| Generate keywords | Time to extract keywords from the user query. |
|
||||||
|
| Retrieval | Time to retrieve the chunks. |
|
||||||
|
| Generate answer | Time to generate the answer. |
|
||||||
@ -22,7 +22,7 @@ When debugging your chat assistant, you can use AI search as a reference to veri
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
- Ensure that you have configured the system's default models on the **Model providers** page.
|
- Ensure that you have configured the system's default models on the **Model providers** page.
|
||||||
- Ensure that the intended knowledge bases are properly configured and the intended documents have finished file parsing.
|
- Ensure that the intended datasets are properly configured and the intended documents have finished file parsing.
|
||||||
|
|
||||||
## Frequently asked questions
|
## Frequently asked questions
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user