mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
Compare commits
84 Commits
v0.20.5
...
b0b866c8fd
Commits in this comparison (84; author and date columns not captured):

b0b866c8fd, 3a831d0c28, 9e323a9351, 7ac95b759b, daea357940, 4aa1abd8e5, 922b5c652d, aaa97874c6,
193d93d820, 4058715df7, 3f595029d7, e8f5a4da56, a9472e3652, 4dd48b60f3, e4ab8ba2de, a1f848bfe0,
f2309ff93e, 38be53cf31, 65a06d62d8, 10cbbb76f8, 1c84d1b562, 4eb7659499, 46a61e5aff, da82566304,
c8b79dfed4, da80fa40bc, 94dbd4aac9, ca9f30e1a1, 2e4295d5ca, d11b1628a1, 45f9f428db, 902703d145,
7ccca2143c, 70ce02faf4, 3f1741c8c6, 6c24ad7966, 4846589599, a24547aa66, a04c5247ab, ed6a76dcc0,
a0ccbec8bd, 4693c5382a, ff3b4d0dcd, 62d35b1b73, 91b609447d, c353840244, f12b9fdcd4, 80ede65bbe,
52cf186028, ea0f1d47a5, 9fe7c92217, d353f7f7f8, f3738b06f1, 5a8bc88147, 04ef5b2783, c9ea22ef69,
152111fd9d, 86f6da2f74, 8c00cbc87a, 41e808f4e6, bc0281040b, 341a7b1473, c29c395390, a23a0f230c,
2a88ce6be1, 664b781d62, 65571e5254, aa30f20730, b9b278d441, e1d86cfee3, 8ebd07337f, dd584d57b0,
3d39b96c6f, 179091b1a4, d14d92a900, 1936ad82d2, 8a09f07186, df8d31451b, fc95d113c3, 7d14455fbe,
bbe6ed3b90, 127af4e45c, 41cdba19ba, 0d9c1f1c3c
8  .github/workflows/release.yml  (vendored)

@@ -88,7 +88,9 @@ jobs:
         with:
           context: .
           push: true
-          tags: infiniflow/ragflow:${{ env.RELEASE_TAG }}
+          tags: |
+            infiniflow/ragflow:${{ env.RELEASE_TAG }}
+            infiniflow/ragflow:latest-full
           file: Dockerfile
           platforms: linux/amd64

@@ -98,7 +100,9 @@ jobs:
         with:
           context: .
           push: true
-          tags: infiniflow/ragflow:${{ env.RELEASE_TAG }}-slim
+          tags: |
+            infiniflow/ragflow:${{ env.RELEASE_TAG }}-slim
+            infiniflow/ragflow:latest-slim
           file: Dockerfile
           build-args: LIGHTEN=1
           platforms: linux/amd64
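With this change, each release build pushes a moving `latest-full` / `latest-slim` tag alongside the versioned tag. A minimal sketch of pulling the resulting images, assuming the workflow has run and the tags have been published to Docker Hub:

```bash
# Pull the most recent full and slim release images (tags added by this workflow change).
docker pull infiniflow/ragflow:latest-full
docker pull infiniflow/ragflow:latest-slim
```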
101  admin/README.md  Normal file

@@ -0,0 +1,101 @@

# RAGFlow Admin Service & CLI

### Introduction

Admin Service is a dedicated management component designed to monitor, maintain, and administer the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.

The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.

For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources, such as knowledge bases and Agents.

Built with scalability and reliability in mind, the Admin Service ensures smooth system operation and simplifies maintenance workflows.

It consists of a server-side service and a command-line client (CLI), both implemented in Python. User commands are parsed using the Lark parsing toolkit; a minimal sketch of the resulting HTTP exchange follows the component list below.

- **Admin Service**: A backend service that interfaces with the RAGFlow system to execute administrative operations and monitor its status.
- **Admin CLI**: A command-line interface that allows users to connect to the Admin Service and issue commands for system management.
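Under the hood, the CLI talks to the Admin Service over plain HTTP with Basic authentication, and every command maps to a REST endpoint. A minimal sketch of that exchange, assuming the service from this PR is running locally on port 9381 with the default bootstrap account created by `admin/auth.py`:

```python
import base64

import requests
from requests.auth import HTTPBasicAuth

# The CLI authenticates with HTTP Basic auth; the password is sent base64-encoded
# (see encode_to_base64 in admin/admin_client.py).
account = "admin@ragflow.io"                    # default bootstrap account from admin/auth.py
password = base64.b64encode(b"admin").decode()  # default bootstrap password, base64-encoded

base = "http://127.0.0.1:9381/api/v1/admin"     # port configured in admin/admin_server.py
auth = HTTPBasicAuth(account, password)

# 1. Verify credentials (what verify_admin does when the CLI starts).
print(requests.get(f"{base}/auth", auth=auth).json())

# 2. List services (what `LIST SERVICES;` translates to).
for svc in requests.get(f"{base}/services", auth=auth).json().get("data", []):
    print(svc["id"], svc["name"], svc["service_type"], f"{svc['host']}:{svc['port']}")
```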
### Starting the Admin Service

1. Before starting the Admin Service, make sure the RAGFlow system is already running.

2. Run the service script:

```bash
python admin/admin_server.py
```

The service will start and listen for incoming connections from the CLI on the configured port.

### Using the Admin CLI

1. Ensure the Admin Service is running.
2. Launch the CLI client:

```bash
python admin/admin_client.py -h 0.0.0.0 -p 9381
```

## Supported Commands

Commands are case-insensitive and must be terminated with a semicolon (`;`).

### Service Management Commands

- `LIST SERVICES;`
  - Lists all available services within the RAGFlow system.
- `SHOW SERVICE <id>;`
  - Shows detailed status information for the service identified by `<id>`.
- `STARTUP SERVICE <id>;`
  - Attempts to start the service identified by `<id>`.
- `SHUTDOWN SERVICE <id>;`
  - Attempts to gracefully shut down the service identified by `<id>`.
- `RESTART SERVICE <id>;`
  - Attempts to restart the service identified by `<id>`.

### User Management Commands

- `LIST USERS;`
  - Lists all users known to the system.
- `SHOW USER '<username>';`
  - Shows details and permissions for the specified user. The username must be enclosed in single or double quotes.
- `DROP USER '<username>';`
  - Removes the specified user from the system. Use with caution.
- `ALTER USER PASSWORD '<username>' '<new_password>';`
  - Changes the password for the specified user.

### Data and Agent Commands

- `LIST DATASETS OF '<username>';`
  - Lists the datasets associated with the specified user.
- `LIST AGENTS OF '<username>';`
  - Lists the agents associated with the specified user.

### Meta-Commands

Meta-commands are prefixed with a backslash (`\`).

- `\?` or `\help`
  - Shows help information for the available commands.
- `\q` or `\quit`
  - Exits the CLI application.

## Examples

```commandline
admin> list users;
+-------------------------------+------------------------+-----------+-------------+
| create_date                   | email                  | is_active | nickname    |
+-------------------------------+------------------------+-----------+-------------+
| Fri, 22 Nov 2024 16:03:41 GMT | jeffery@infiniflow.org | 1         | Jeffery     |
| Fri, 22 Nov 2024 16:10:55 GMT | aya@infiniflow.org     | 1         | Waterdancer |
+-------------------------------+------------------------+-----------+-------------+

admin> list services;
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
| extra                                                                                       | host      | id | name          | port  | service_type   |
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
| {}                                                                                          | 0.0.0.0   | 0  | ragflow_0     | 9380  | ragflow_server |
| {'meta_type': 'mysql', 'password': 'infini_rag_flow', 'username': 'root'}                   | localhost | 1  | mysql         | 5455  | meta_data      |
| {'password': 'infini_rag_flow', 'store_type': 'minio', 'user': 'rag_flow'}                  | localhost | 2  | minio         | 9000  | file_store     |
| {'password': 'infini_rag_flow', 'retrieval_type': 'elasticsearch', 'username': 'elastic'}   | localhost | 3  | elasticsearch | 1200  | retrieval      |
| {'db_name': 'default_db', 'retrieval_type': 'infinity'}                                     | localhost | 4  | infinity      | 23817 | retrieval      |
| {'database': 1, 'mq_type': 'redis', 'password': 'infini_rag_flow'}                          | localhost | 5  | redis         | 6379  | message_queue  |
+-------------------------------------------------------------------------------------------+-----------+----+---------------+-------+----------------+
```
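For scripted administration, the same REST endpoints the CLI uses can be called directly. A minimal sketch of the password-change flow behind `ALTER USER PASSWORD ...;`, assuming it is run from the `admin/` directory (so `admin_client` is importable), the service is on `localhost:9381`, and the default bootstrap admin account is in place; the target email and new password below are placeholders:

```python
import requests
from requests.auth import HTTPBasicAuth

from admin_client import encode_to_base64, encrypt  # helpers added in this PR

admin_auth = HTTPBasicAuth("admin@ragflow.io", encode_to_base64("admin"))
target_user = "aya@infiniflow.org"      # placeholder account (from the example table above)
new_password = "a-stronger-password"    # placeholder value

resp = requests.put(
    f"http://127.0.0.1:9381/api/v1/admin/users/{target_user}/password",
    auth=admin_auth,
    json={"new_password": encrypt(new_password)},  # the password is RSA-encrypted client-side
)
print(resp.json()["message"])
```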
567  admin/admin_client.py  Normal file

@@ -0,0 +1,567 @@

import argparse
import base64
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from typing import Dict, List, Any
from lark import Lark, Transformer, Tree
import requests
from requests.auth import HTTPBasicAuth

GRAMMAR = r"""
start: command

command: sql_command | meta_command

sql_command: list_services
           | show_service
           | startup_service
           | shutdown_service
           | restart_service
           | list_users
           | show_user
           | drop_user
           | alter_user
           | create_user
           | activate_user
           | list_datasets
           | list_agents

// meta command definition
meta_command: "\\" meta_command_name [meta_args]

meta_command_name: /[a-zA-Z?]+/
meta_args: (meta_arg)+

meta_arg: /[^\\s"']+/ | quoted_string

// command definition

LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
RESTART: "RESTART"i
USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
PASSWORD: "PASSWORD"i
DATASETS: "DATASETS"i
OF: "OF"i
AGENTS: "AGENTS"i

list_services: LIST SERVICES ";"
show_service: SHOW SERVICE NUMBER ";"
startup_service: STARTUP SERVICE NUMBER ";"
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
restart_service: RESTART SERVICE NUMBER ";"

list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"

list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"

identifier: WORD
quoted_string: QUOTED_STRING
status: WORD

QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
NUMBER: /[0-9]+/

%import common.WS
%ignore WS
"""


class AdminTransformer(Transformer):

    def start(self, items):
        return items[0]

    def command(self, items):
        return items[0]

    def list_services(self, items):
        result = {'type': 'list_services'}
        return result

    def show_service(self, items):
        service_id = int(items[2])
        return {"type": "show_service", "number": service_id}

    def startup_service(self, items):
        service_id = int(items[2])
        return {"type": "startup_service", "number": service_id}

    def shutdown_service(self, items):
        service_id = int(items[2])
        return {"type": "shutdown_service", "number": service_id}

    def restart_service(self, items):
        service_id = int(items[2])
        return {"type": "restart_service", "number": service_id}

    def list_users(self, items):
        return {"type": "list_users"}

    def show_user(self, items):
        user_name = items[2]
        return {"type": "show_user", "username": user_name}

    def drop_user(self, items):
        user_name = items[2]
        return {"type": "drop_user", "username": user_name}

    def alter_user(self, items):
        user_name = items[3]
        new_password = items[4]
        return {"type": "alter_user", "username": user_name, "password": new_password}

    def create_user(self, items):
        user_name = items[2]
        password = items[3]
        return {"type": "create_user", "username": user_name, "password": password, "role": "user"}

    def activate_user(self, items):
        user_name = items[3]
        activate_status = items[4]
        return {"type": "activate_user", "activate_status": activate_status, "username": user_name}

    def list_datasets(self, items):
        user_name = items[3]
        return {"type": "list_datasets", "username": user_name}

    def list_agents(self, items):
        user_name = items[3]
        return {"type": "list_agents", "username": user_name}

    def meta_command(self, items):
        command_name = str(items[0]).lower()
        args = items[1:] if len(items) > 1 else []

        # handle quoted parameter
        parsed_args = []
        for arg in args:
            if hasattr(arg, 'value'):
                parsed_args.append(arg.value)
            else:
                parsed_args.append(str(arg))

        return {'type': 'meta', 'command': command_name, 'args': parsed_args}

    def meta_command_name(self, items):
        return items[0]

    def meta_args(self, items):
        return items


def encode_to_base64(input_string):
    base64_encoded = base64.b64encode(input_string.encode('utf-8'))
    return base64_encoded.decode('utf-8')


def encrypt(input_string):
    pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
    pub_key = RSA.importKey(pub)
    cipher = Cipher_pkcs1_v1_5.new(pub_key)
    cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
    return base64.b64encode(cipher_text).decode("utf-8")


class AdminCommandParser:
    def __init__(self):
        self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
        self.command_history = []

    def parse_command(self, command_str: str) -> Dict[str, Any]:
        if not command_str.strip():
            return {'type': 'empty'}

        self.command_history.append(command_str)

        try:
            result = self.parser.parse(command_str)
            return result
        except Exception as e:
            return {'type': 'error', 'message': f'Parse error: {str(e)}'}


class AdminCLI:
    def __init__(self):
        self.parser = AdminCommandParser()
        self.is_interactive = False
        self.admin_account = "admin@ragflow.io"
        self.admin_password: str = "admin"
        self.host: str = ""
        self.port: int = 0

    def verify_admin(self, args):

        conn_info = self._parse_connection_args(args)
        if 'error' in conn_info:
            print(f"Error: {conn_info['error']}")
            return

        self.host = conn_info['host']
        self.port = conn_info['port']
        print(f"Attempt to access ip: {self.host}, port: {self.port}")
        url = f'http://{self.host}:{self.port}/api/v1/admin/auth'

        try_count = 0
        while True:
            try_count += 1
            if try_count > 3:
                return False

            admin_passwd = input(f"password for {self.admin_account}: ").strip()
            try:
                self.admin_password = encode_to_base64(admin_passwd)
                response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
                if response.status_code == 200:
                    res_json = response.json()
                    error_code = res_json.get('code', -1)
                    if error_code == 0:
                        print("Authentication successful.")
                        return True
                    else:
                        error_message = res_json.get('message', 'Unknown error')
                        print(f"Authentication failed: {error_message}, try again")
                        continue
                else:
                    print(f"Bad response,status: {response.status_code}, try again")
            except Exception:
                print(f"Can't access {self.host}, port: {self.port}")

    def _print_table_simple(self, data):
        if not data:
            print("No data to print")
            return
        if isinstance(data, dict):
            # handle single row data
            data = [data]

        columns = list(data[0].keys())
        col_widths = {}

        for col in columns:
            max_width = len(str(col))
            for item in data:
                value_len = len(str(item.get(col, '')))
                if value_len > max_width:
                    max_width = value_len
            col_widths[col] = max(2, max_width)

        # Generate delimiter
        separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"

        # Print header
        print(separator)
        header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
        print(header)
        print(separator)

        # Print data
        for item in data:
            row = "|"
            for col in columns:
                value = str(item.get(col, ''))
                if len(value) > col_widths[col]:
                    value = value[:col_widths[col] - 3] + "..."
                row += f" {value:<{col_widths[col]}} |"
            print(row)

        print(separator)

    def run_interactive(self):

        self.is_interactive = True
        print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")

        while True:
            try:
                command = input("admin> ").strip()
                if not command:
                    continue

                print(f"command: {command}")
                result = self.parser.parse_command(command)
                self.execute_command(result)

                if isinstance(result, Tree):
                    continue

                if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
                    break

            except KeyboardInterrupt:
                print("\nUse '\\q' to quit")
            except EOFError:
                print("\nGoodbye!")
                break

    def run_single_command(self, args):
        conn_info = self._parse_connection_args(args)
        if 'error' in conn_info:
            print(f"Error: {conn_info['error']}")
            return

    def _parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
        parser = argparse.ArgumentParser(description='Admin CLI Client', add_help=False)
        parser.add_argument('-h', '--host', default='localhost', help='Admin service host')
        parser.add_argument('-p', '--port', type=int, default=8080, help='Admin service port')

        try:
            parsed_args, remaining_args = parser.parse_known_args(args)
            return {
                'host': parsed_args.host,
                'port': parsed_args.port,
            }
        except SystemExit:
            return {'error': 'Invalid connection arguments'}

    def execute_command(self, parsed_command: Dict[str, Any]):

        command_dict: dict
        if isinstance(parsed_command, Tree):
            command_dict = parsed_command.children[0]
        else:
            if parsed_command['type'] == 'error':
                print(f"Error: {parsed_command['message']}")
                return
            else:
                command_dict = parsed_command

        # print(f"Parsed command: {command_dict}")

        command_type = command_dict['type']

        match command_type:
            case 'list_services':
                self._handle_list_services(command_dict)
            case 'show_service':
                self._handle_show_service(command_dict)
            case 'restart_service':
                self._handle_restart_service(command_dict)
            case 'shutdown_service':
                self._handle_shutdown_service(command_dict)
            case 'startup_service':
                self._handle_startup_service(command_dict)
            case 'list_users':
                self._handle_list_users(command_dict)
            case 'show_user':
                self._handle_show_user(command_dict)
            case 'drop_user':
                self._handle_drop_user(command_dict)
            case 'alter_user':
                self._handle_alter_user(command_dict)
            case 'create_user':
                self._handle_create_user(command_dict)
            case 'activate_user':
                self._handle_activate_user(command_dict)
            case 'list_datasets':
                self._handle_list_datasets(command_dict)
            case 'list_agents':
                self._handle_list_agents(command_dict)
            case 'meta':
                self._handle_meta_command(command_dict)
            case _:
                print(f"Command '{command_type}' would be executed with API")

    def _handle_list_services(self, command):
        print("Listing all services")

        url = f'http://{self.host}:{self.port}/api/v1/admin/services'
        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
        res_json = response.json()
        if response.status_code == 200:
            self._print_table_simple(res_json['data'])
        else:
            print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")

    def _handle_show_service(self, command):
        service_id: int = command['number']
        print(f"Showing service: {service_id}")

    def _handle_restart_service(self, command):
        service_id: int = command['number']
        print(f"Restart service {service_id}")

    def _handle_shutdown_service(self, command):
        service_id: int = command['number']
        print(f"Shutdown service {service_id}")

    def _handle_startup_service(self, command):
        service_id: int = command['number']
        print(f"Startup service {service_id}")

    def _handle_list_users(self, command):
        print("Listing all users")

        url = f'http://{self.host}:{self.port}/api/v1/admin/users'
        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
        res_json = response.json()
        if response.status_code == 200:
            self._print_table_simple(res_json['data'])
        else:
            print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")

    def _handle_show_user(self, command):
        username_tree: Tree = command['username']
        username: str = username_tree.children[0].strip("'\"")
        print(f"Showing user: {username}")
        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
        res_json = response.json()
        if response.status_code == 200:
            self._print_table_simple(res_json['data'])
        else:
            print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")

    def _handle_drop_user(self, command):
        username_tree: Tree = command['username']
        username: str = username_tree.children[0].strip("'\"")
        print(f"Drop user: {username}")

    def _handle_alter_user(self, command):
        username_tree: Tree = command['username']
        username: str = username_tree.children[0].strip("'\"")
        password_tree: Tree = command['password']
        password: str = password_tree.children[0].strip("'\"")
        print(f"Alter user: {username}, password: {password}")
        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
        response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'new_password': encrypt(password)})
        res_json = response.json()
        if response.status_code == 200:
            print(res_json["message"])
        else:
            print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")

    def _handle_create_user(self, command):
        username_tree: Tree = command['username']
        username: str = username_tree.children[0].strip("'\"")
        password_tree: Tree = command['password']
        password: str = password_tree.children[0].strip("'\"")
        role: str = command['role']
        print(f"Create user: {username}, password: {password}, role: {role}")
        url = f'http://{self.host}:{self.port}/api/v1/admin/users'
        response = requests.post(
            url,
            auth=HTTPBasicAuth(self.admin_account, self.admin_password),
            json={'username': username, 'password': encrypt(password), 'role': role}
        )
        res_json = response.json()
        if response.status_code == 200:
            self._print_table_simple(res_json['data'])
        else:
            print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")

    def _handle_activate_user(self, command):
        username_tree: Tree = command['username']
        username: str = username_tree.children[0].strip("'\"")
        activate_tree: Tree = command['activate_status']
        activate_status: str = activate_tree.children[0].strip("'\"")
        if activate_status.lower() in ['on', 'off']:
            print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
            url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
            response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password), json={'activate_status': activate_status})
            res_json = response.json()
            if response.status_code == 200:
                print(res_json["message"])
            else:
                print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
        else:
            print(f"Unknown activate status: {activate_status}.")

    def _handle_list_datasets(self, command):
        username_tree: Tree = command['username']
        username: str = username_tree.children[0].strip("'\"")
        print(f"Listing all datasets of user: {username}")
        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
        res_json = response.json()
        if response.status_code == 200:
            self._print_table_simple(res_json['data'])
        else:
            print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")

    def _handle_list_agents(self, command):
        username_tree: Tree = command['username']
        username: str = username_tree.children[0].strip("'\"")
        print(f"Listing all agents of user: {username}")
        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
        res_json = response.json()
        if response.status_code == 200:
            self._print_table_simple(res_json['data'])
        else:
            print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")

    def _handle_meta_command(self, command):
        meta_command = command['command']
        args = command.get('args', [])

        if meta_command in ['?', 'h', 'help']:
            self.show_help()
        elif meta_command in ['q', 'quit', 'exit']:
            print("Goodbye!")
        else:
            print(f"Meta command '{meta_command}' with args {args}")

    def show_help(self):
        """Help info"""
        help_text = """
Commands:
  LIST SERVICES
  SHOW SERVICE <service>
  STARTUP SERVICE <service>
  SHUTDOWN SERVICE <service>
  RESTART SERVICE <service>
  LIST USERS
  SHOW USER <user>
  DROP USER <user>
  CREATE USER <user> <password>
  ALTER USER PASSWORD <user> <new_password>
  LIST DATASETS OF <user>
  LIST AGENTS OF <user>

Meta Commands:
  \\?, \\h, \\help     Show this help
  \\q, \\quit, \\exit  Quit the CLI
"""
        print(help_text)


def main():
    import sys

    cli = AdminCLI()

    if len(sys.argv) == 1 or (len(sys.argv) > 1 and sys.argv[1] == '-'):
        print(r"""
    ____  ___   ______________                 ___       __          _
   / __ \/   | / ____/ ____/ /___ _      __   /   | ____/ /___ ___  (_)___
  / /_/ / /| |/ / __/ /_  / / __ \ | /| / /  / /| |/ __  / __ `__ \/ / __ \
 / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ /  / ___ / /_/ / / / / / / / / / /
/_/ |_/_/  |_\____/_/   /_/\____/|__/|__/  /_/  |_\__,_/_/ /_/ /_/_/_/ /_/
""")
        if cli.verify_admin(sys.argv):
            cli.run_interactive()
    else:
        if cli.verify_admin(sys.argv):
            cli.run_interactive()
        # cli.run_single_command(sys.argv[1:])


if __name__ == '__main__':
    main()
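The grammar above is compiled with Lark's LALR parser, and `AdminTransformer` rewrites each statement into a plain dict during parsing; SQL-style statements come back wrapped in a `sql_command` tree node, which `execute_command` unwraps. A minimal sketch of exercising the parser on its own, assuming `lark` is installed and the snippet is run from the `admin/` directory so `admin_client` is importable:

```python
from lark import Tree

from admin_client import AdminCommandParser  # module added in this PR

parser = AdminCommandParser()

# SQL-style commands come back as Tree('sql_command', [dict]); unwrap like execute_command does.
result = parser.parse_command("LIST USERS;")
command = result.children[0] if isinstance(result, Tree) else result
print(command)  # {'type': 'list_users'}

result = parser.parse_command("SHOW SERVICE 3;")
command = result.children[0] if isinstance(result, Tree) else result
print(command["type"], command["number"])  # show_service 3
```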
47  admin/admin_server.py  Normal file

@@ -0,0 +1,47 @@

import os
import signal
import logging
import time
import threading
import traceback
from werkzeug.serving import run_simple
from flask import Flask
from routes import admin_bp
from api.utils.log_utils import init_root_logger
from api.constants import SERVICE_CONF
from api import settings
from config import load_configurations, SERVICE_CONFIGS

stop_event = threading.Event()

if __name__ == '__main__':
    init_root_logger("admin_service")
    logging.info(r"""
    ____  ___   ______________                 ___       __          _
   / __ \/   | / ____/ ____/ /___ _      __   /   | ____/ /___ ___  (_)___
  / /_/ / /| |/ / __/ /_  / / __ \ | /| / /  / /| |/ __  / __ `__ \/ / __ \
 / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ /  / ___ / /_/ / / / / / / / / / /
/_/ |_/_/  |_\____/_/   /_/\____/|__/|__/  /_/  |_\__,_/_/ /_/ /_/_/_/ /_/
""")

    app = Flask(__name__)
    app.register_blueprint(admin_bp)
    settings.init_settings()
    SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)

    try:
        logging.info("RAGFlow Admin service start...")
        run_simple(
            hostname="0.0.0.0",
            port=9381,
            application=app,
            threaded=True,
            use_reloader=True,
            use_debugger=True,
        )
    except Exception:
        traceback.print_exc()
        stop_event.set()
        time.sleep(1)
        os.kill(os.getpid(), signal.SIGKILL)
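Once the service is listening on port 9381, the `/auth` endpoint doubles as a quick liveness and credentials check. A minimal sketch, assuming the default bootstrap account `admin@ragflow.io` / `admin` created by `admin/auth.py`; note the password is sent base64-encoded, mirroring what the CLI does:

```bash
# HTTP Basic auth with the base64-encoded password, as the Admin CLI sends it.
curl -s -u "admin@ragflow.io:$(printf 'admin' | base64)" \
     http://127.0.0.1:9381/api/v1/admin/auth
# Expected envelope on success: {"code": 0, "message": "Admin is authorized", "data": null}
```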
57  admin/auth.py  Normal file

@@ -0,0 +1,57 @@

import logging
import uuid
from functools import wraps
from flask import request, jsonify

from exceptions import AdminException
from api.db.init_data import encode_to_base64
from api.db.services import UserService


def check_admin(username: str, password: str):
    users = UserService.query(email=username)
    if not users:
        logging.info(f"Username: {username} is not registered!")
        user_info = {
            "id": uuid.uuid1().hex,
            "password": encode_to_base64("admin"),
            "nickname": "admin",
            "is_superuser": True,
            "email": "admin@ragflow.io",
            "creator": "system",
            "status": "1",
        }
        if not UserService.save(**user_info):
            raise AdminException("Can't init admin.", 500)

    user = UserService.query_user(username, password)
    if user:
        return True
    else:
        return False


def login_verify(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        auth = request.authorization
        if not auth or 'username' not in auth.parameters or 'password' not in auth.parameters:
            return jsonify({
                "code": 401,
                "message": "Authentication required",
                "data": None
            }), 200

        username = auth.parameters['username']
        password = auth.parameters['password']
        # TODO: to check the username and password from DB
        if check_admin(username, password) is False:
            return jsonify({
                "code": 403,
                "message": "Access denied",
                "data": None
            }), 200

        return f(*args, **kwargs)

    return decorated
280  admin/config.py  Normal file

@@ -0,0 +1,280 @@

import logging
import threading
from enum import Enum

from pydantic import BaseModel
from typing import Any
from api.utils import read_config
from urllib.parse import urlparse


class ServiceConfigs:
    def __init__(self):
        self.configs = []
        self.lock = threading.Lock()


SERVICE_CONFIGS = ServiceConfigs


class ServiceType(Enum):
    METADATA = "metadata"
    RETRIEVAL = "retrieval"
    MESSAGE_QUEUE = "message_queue"
    RAGFLOW_SERVER = "ragflow_server"
    TASK_EXECUTOR = "task_executor"
    FILE_STORE = "file_store"


class BaseConfig(BaseModel):
    id: int
    name: str
    host: str
    port: int
    service_type: str

    def to_dict(self) -> dict[str, Any]:
        return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, 'service_type': self.service_type}


class MetaConfig(BaseConfig):
    meta_type: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['meta_type'] = self.meta_type
        result['extra'] = extra_dict
        return result


class MySQLConfig(MetaConfig):
    username: str
    password: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['username'] = self.username
        extra_dict['password'] = self.password
        result['extra'] = extra_dict
        return result


class PostgresConfig(MetaConfig):

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        return result


class RetrievalConfig(BaseConfig):
    retrieval_type: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['retrieval_type'] = self.retrieval_type
        result['extra'] = extra_dict
        return result


class InfinityConfig(RetrievalConfig):
    db_name: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['db_name'] = self.db_name
        result['extra'] = extra_dict
        return result


class ElasticsearchConfig(RetrievalConfig):
    username: str
    password: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['username'] = self.username
        extra_dict['password'] = self.password
        result['extra'] = extra_dict
        return result


class MessageQueueConfig(BaseConfig):
    mq_type: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['mq_type'] = self.mq_type
        result['extra'] = extra_dict
        return result


class RedisConfig(MessageQueueConfig):
    database: int
    password: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['database'] = self.database
        extra_dict['password'] = self.password
        result['extra'] = extra_dict
        return result


class RabbitMQConfig(MessageQueueConfig):

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        return result


class RAGFlowServerConfig(BaseConfig):

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        return result


class TaskExecutorConfig(BaseConfig):

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        return result


class FileStoreConfig(BaseConfig):
    store_type: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['store_type'] = self.store_type
        result['extra'] = extra_dict
        return result


class MinioConfig(FileStoreConfig):
    user: str
    password: str

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if 'extra' not in result:
            result['extra'] = dict()
        extra_dict = result['extra'].copy()
        extra_dict['user'] = self.user
        extra_dict['password'] = self.password
        result['extra'] = extra_dict
        return result


def load_configurations(config_path: str) -> list[BaseConfig]:
    raw_configs = read_config(config_path)
    configurations = []
    ragflow_count = 0
    id_count = 0
    for k, v in raw_configs.items():
        match (k):
            case "ragflow":
                name: str = f'ragflow_{ragflow_count}'
                host: str = v['host']
                http_port: int = v['http_port']
                config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port, service_type="ragflow_server")
                configurations.append(config)
                id_count += 1
            case "es":
                name: str = 'elasticsearch'
                url = v['hosts']
                parsed = urlparse(url)
                host: str = parsed.hostname
                port: int = parsed.port
                username: str = v.get('username')
                password: str = v.get('password')
                config = ElasticsearchConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval",
                                             retrieval_type="elasticsearch",
                                             username=username, password=password)
                configurations.append(config)
                id_count += 1

            case "infinity":
                name: str = 'infinity'
                url = v['uri']
                parts = url.split(':', 1)
                host = parts[0]
                port = int(parts[1])
                database: str = v.get('db_name', 'default_db')
                config = InfinityConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", retrieval_type="infinity",
                                        db_name=database)
                configurations.append(config)
                id_count += 1
            case "minio":
                name: str = 'minio'
                url = v['host']
                parts = url.split(':', 1)
                host = parts[0]
                port = int(parts[1])
                user = v.get('user')
                password = v.get('password')
                config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, service_type="file_store",
                                     store_type="minio")
                configurations.append(config)
                id_count += 1
            case "redis":
                name: str = 'redis'
                url = v['host']
                parts = url.split(':', 1)
                host = parts[0]
                port = int(parts[1])
                password = v.get('password')
                db: int = v.get('db')
                config = RedisConfig(id=id_count, name=name, host=host, port=port, password=password, database=db,
                                     service_type="message_queue", mq_type="redis")
                configurations.append(config)
                id_count += 1
            case "mysql":
                name: str = 'mysql'
                host: str = v.get('host')
                port: int = v.get('port')
                username = v.get('user')
                password = v.get('password')
                config = MySQLConfig(id=id_count, name=name, host=host, port=port, username=username, password=password,
                                     service_type="meta_data", meta_type="mysql")
                configurations.append(config)
                id_count += 1
            case "admin":
                pass
            case _:
                logging.warning(f"Unknown configuration key: {k}")
                continue

    return configurations
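`load_configurations` reads the same `service_conf.yaml` the rest of RAGFlow uses (`SERVICE_CONF` from `api.constants`) and keys off the section names matched above. A rough sketch of the sections it understands, with placeholder values modeled on the README example; the exact file in a deployment will differ:

```yaml
ragflow:
  host: 0.0.0.0
  http_port: 9380
mysql:
  user: root
  password: infini_rag_flow
  host: localhost
  port: 5455
minio:
  user: rag_flow
  password: infini_rag_flow
  host: localhost:9000
es:
  hosts: http://localhost:1200
  username: elastic
  password: infini_rag_flow
infinity:
  uri: localhost:23817
  db_name: default_db
redis:
  db: 1
  password: infini_rag_flow
  host: localhost:6379
```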
17  admin/exceptions.py  Normal file

@@ -0,0 +1,17 @@

class AdminException(Exception):
    def __init__(self, message, code=400):
        super().__init__(message)
        self.code = code
        self.message = message


class UserNotFoundError(AdminException):
    def __init__(self, username):
        super().__init__(f"User '{username}' not found", 404)


class UserAlreadyExistsError(AdminException):
    def __init__(self, username):
        super().__init__(f"User '{username}' already exists", 409)


class CannotDeleteAdminError(AdminException):
    def __init__(self):
        super().__init__("Cannot delete admin account", 403)
0  admin/models.py  Normal file

15  admin/responses.py  Normal file

@@ -0,0 +1,15 @@

from flask import jsonify

def success_response(data=None, message="Success", code=0):
    return jsonify({
        "code": code,
        "message": message,
        "data": data
    }), 200


def error_response(message="Error", code=-1, data=None):
    return jsonify({
        "code": code,
        "message": message,
        "data": data
    }), 400
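Every admin endpoint replies with this envelope: HTTP 200 for success, HTTP 400 for handled errors. For example, a successful `GET /api/v1/admin/auth` and a failed user creation would produce bodies like:

```json
{"code": 0, "message": "Admin is authorized", "data": null}
```

```json
{"code": -1, "message": "create user failed", "data": null}
```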
187  admin/routes.py  Normal file

@@ -0,0 +1,187 @@

from flask import Blueprint, request

from auth import login_verify
from responses import success_response, error_response
from services import UserMgr, ServiceMgr, UserServiceMgr
from exceptions import AdminException

admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')


@admin_bp.route('/auth', methods=['GET'])
@login_verify
def auth_admin():
    try:
        return success_response(None, "Admin is authorized", 0)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/users', methods=['GET'])
@login_verify
def list_users():
    try:
        users = UserMgr.get_all_users()
        return success_response(users, "Get all users", 0)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/users', methods=['POST'])
@login_verify
def create_user():
    try:
        data = request.get_json()
        if not data or 'username' not in data or 'password' not in data:
            return error_response("Username and password are required", 400)

        username = data['username']
        password = data['password']
        role = data.get('role', 'user')

        res = UserMgr.create_user(username, password, role)
        if res["success"]:
            user_info = res["user_info"]
            user_info.pop("password")  # do not return password
            return success_response(user_info, "User created successfully")
        else:
            return error_response("create user failed")

    except AdminException as e:
        return error_response(e.message, e.code)
    except Exception as e:
        return error_response(str(e))


@admin_bp.route('/users/<username>', methods=['DELETE'])
@login_verify
def delete_user(username):
    try:
        UserMgr.delete_user(username)
        return success_response(None, "User and all data deleted successfully")

    except AdminException as e:
        return error_response(e.message, e.code)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/users/<username>/password', methods=['PUT'])
@login_verify
def change_password(username):
    try:
        data = request.get_json()
        if not data or 'new_password' not in data:
            return error_response("New password is required", 400)

        new_password = data['new_password']
        msg = UserMgr.update_user_password(username, new_password)
        return success_response(None, msg)

    except AdminException as e:
        return error_response(e.message, e.code)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/users/<username>/activate', methods=['PUT'])
@login_verify
def alter_user_activate_status(username):
    try:
        data = request.get_json()
        if not data or 'activate_status' not in data:
            return error_response("Activation status is required", 400)
        activate_status = data['activate_status']
        msg = UserMgr.update_user_activate_status(username, activate_status)
        return success_response(None, msg)
    except AdminException as e:
        return error_response(e.message, e.code)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/users/<username>', methods=['GET'])
@login_verify
def get_user_details(username):
    try:
        user_details = UserMgr.get_user_details(username)
        return success_response(user_details)

    except AdminException as e:
        return error_response(e.message, e.code)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/users/<username>/datasets', methods=['GET'])
@login_verify
def get_user_datasets(username):
    try:
        datasets_list = UserServiceMgr.get_user_datasets(username)
        return success_response(datasets_list)

    except AdminException as e:
        return error_response(e.message, e.code)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/users/<username>/agents', methods=['GET'])
@login_verify
def get_user_agents(username):
    try:
        agents_list = UserServiceMgr.get_user_agents(username)
        return success_response(agents_list)

    except AdminException as e:
        return error_response(e.message, e.code)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/services', methods=['GET'])
@login_verify
def get_services():
    try:
        services = ServiceMgr.get_all_services()
        return success_response(services, "Get all services", 0)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/service_types/<service_type>', methods=['GET'])
@login_verify
def get_services_by_type(service_type_str):
    try:
        services = ServiceMgr.get_services_by_type(service_type_str)
        return success_response(services)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/services/<service_id>', methods=['GET'])
@login_verify
def get_service(service_id):
    try:
        services = ServiceMgr.get_service_details(service_id)
        return success_response(services)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/services/<service_id>', methods=['DELETE'])
@login_verify
def shutdown_service(service_id):
    try:
        services = ServiceMgr.shutdown_service(service_id)
        return success_response(services)
    except Exception as e:
        return error_response(str(e), 500)


@admin_bp.route('/services/<service_id>', methods=['PUT'])
@login_verify
def restart_service(service_id):
    try:
        services = ServiceMgr.restart_service(service_id)
        return success_response(services)
    except Exception as e:
        return error_response(str(e), 500)
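New admin endpoints follow the same pattern: register a route on `admin_bp`, guard it with `@login_verify`, and answer with the `success_response` / `error_response` envelope. A minimal sketch of a hypothetical extra endpoint; the `/ping` path and handler below are illustrative only and not part of this PR:

```python
# Hypothetical extension of admin/routes.py; '/ping' is not in this PR.
from auth import login_verify
from responses import success_response, error_response


@admin_bp.route('/ping', methods=['GET'])
@login_verify
def ping():
    try:
        # Trivial authenticated liveness probe using the shared response envelope.
        return success_response({"status": "ok"}, "pong", 0)
    except Exception as e:
        return error_response(str(e), 500)
```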
163
admin/services.py
Normal file
163
admin/services.py
Normal file
@ -0,0 +1,163 @@
import re
from werkzeug.security import check_password_hash
from api.db import ActiveEnum
from api.db.services import UserService
from api.db.joint_services.user_account_service import create_new_user
from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_service import TenantService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.crypt import decrypt
from exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
from config import SERVICE_CONFIGS


class UserMgr:
    @staticmethod
    def get_all_users():
        users = UserService.get_all_users()
        result = []
        for user in users:
            result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date, 'is_active': user.is_active})
        return result

    @staticmethod
    def get_user_details(username):
        # use email to query
        users = UserService.query_user_by_email(username)
        result = []
        for user in users:
            result.append({
                'email': user.email,
                'language': user.language,
                'last_login_time': user.last_login_time,
                'is_authenticated': user.is_authenticated,
                'is_active': user.is_active,
                'is_anonymous': user.is_anonymous,
                'login_channel': user.login_channel,
                'status': user.status,
                'is_superuser': user.is_superuser,
                'create_date': user.create_date,
                'update_date': user.update_date
            })
        return result

    @staticmethod
    def create_user(username, password, role="user") -> dict:
        # Validate the email address
        if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
            raise AdminException(f"Invalid email address: {username}!")
        # Check if the email address is already used
        if UserService.query(email=username):
            raise UserAlreadyExistsError(username)
        # Construct user info data
        user_info_dict = {
            "email": username,
            "nickname": "",  # ask user to edit it manually in settings.
            "password": decrypt(password),
            "login_channel": "password",
            "is_superuser": role == "admin",
        }
        return create_new_user(user_info_dict)

    @staticmethod
    def delete_user(username):
        # use email to delete
        raise AdminException("delete_user: not implemented")

    @staticmethod
    def update_user_password(username, new_password) -> str:
        # use email to find user. check exist and unique.
        user_list = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"Exist more than 1 user: {username}!")
        # check new_password different from old.
        usr = user_list[0]
        psw = decrypt(new_password)
        if check_password_hash(usr.password, psw):
            return "Same password, no need to update!"
        # update password
        UserService.update_user_password(usr.id, psw)
        return "Password updated successfully!"

    @staticmethod
    def update_user_activate_status(username, activate_status: str):
        # use email to find user. check exist and unique.
        user_list = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"Exist more than 1 user: {username}!")
        # check activate status different from new
        usr = user_list[0]
        # format activate_status before handle
        _activate_status = activate_status.lower()
        target_status = {
            'on': ActiveEnum.ACTIVE.value,
            'off': ActiveEnum.INACTIVE.value,
        }.get(_activate_status)
        if not target_status:
            raise AdminException(f"Invalid activate_status: {activate_status}")
        if target_status == usr.is_active:
            return f"User activate status is already {_activate_status}!"
        # update is_active
        UserService.update_user(usr.id, {"is_active": target_status})
        return f"Turn {_activate_status} user activate status successfully!"


class UserServiceMgr:

    @staticmethod
    def get_user_datasets(username):
        # use email to find user.
        user_list = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"Exist more than 1 user: {username}!")
        # find tenants
        usr = user_list[0]
        tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
        tenant_ids = [m["tenant_id"] for m in tenants]
        # filter permitted kb and owned kb
        return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)

    @staticmethod
    def get_user_agents(username):
        # use email to find user.
        user_list = UserService.query_user_by_email(username)
        if not user_list:
            raise UserNotFoundError(username)
        elif len(user_list) > 1:
            raise AdminException(f"Exist more than 1 user: {username}!")
        # find tenants
        usr = user_list[0]
        tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
        tenant_ids = [m["tenant_id"] for m in tenants]
        # filter permitted agents and owned agents
        return UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)


class ServiceMgr:

    @staticmethod
    def get_all_services():
        result = []
        configs = SERVICE_CONFIGS.configs
        for config in configs:
            result.append(config.to_dict())
        return result

    @staticmethod
    def get_services_by_type(service_type_str: str):
        raise AdminException("get_services_by_type: not implemented")

    @staticmethod
    def get_service_details(service_id: int):
        raise AdminException("get_service_details: not implemented")

    @staticmethod
    def shutdown_service(service_id: int):
        raise AdminException("shutdown_service: not implemented")

    @staticmethod
    def restart_service(service_id: int):
        raise AdminException("restart_service: not implemented")
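Note: the manager classes above are thin wrappers over the existing service layer, and several methods intentionally raise AdminException("... not implemented") for now. As a minimal sketch only (assuming the module is importable as admin.services and the database is initialised), a caller such as the admin CLI could exercise the read-only paths like this:

import logging

from admin.services import ServiceMgr, UserMgr

# List every configured service entry (one dict per item in SERVICE_CONFIGS).
for svc in ServiceMgr.get_all_services():
    logging.info(svc)

# List registered users with their activation state.
for user in UserMgr.get_all_users():
    logging.info("%s active=%s", user["email"], user["is_active"])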
@@ -27,7 +27,7 @@ from agent.component import component_class
 from agent.component.base import ComponentBase
 from api.db.services.file_service import FileService
 from api.utils import get_uuid, hash_str2int
-from rag.prompts.prompts import chunks_format
+from rag.prompts.generator import chunks_format
 from rag.utils.redis_conn import REDIS_CONN


 class Graph:
@@ -490,7 +490,8 @@ class Canvas(Graph):

         r = self.retrieval[-1]
         for ck in chunks_format({"chunks": chunks}):
-            cid = hash_str2int(ck["id"], 100)
+            cid = hash_str2int(ck["id"], 500)
+            # cid = uuid.uuid5(uuid.NAMESPACE_DNS, ck["id"])
             if cid not in r:
                 r["chunks"][cid] = ck

@@ -28,9 +28,8 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from api.utils.api_utils import timeout
-from rag.prompts import message_fit_in
-from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
-    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
+from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
+    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM

@@ -244,7 +244,7 @@ class ComponentParamBase(ABC):

         if not value_legal:
             raise ValueError(
-                "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
+                "Please check runtime conf, {} = {} does not match user-parameter restriction".format(
                     variable, value
                 )
             )
@@ -28,7 +28,7 @@ from rag.llm.chat_model import ERROR_PREFIX
 class CategorizeParam(LLMParam):

     """
-    Define the Categorize component parameters.
+    Define the categorize component parameters.
     """
     def __init__(self):
         super().__init__()
@@ -96,7 +96,7 @@ Here's description of each category:
 class Categorize(LLM, ABC):
     component_name = "Categorize"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         msg = self._canvas.get_history(self._param.message_history_window_size)
         if not msg:
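The int(...) wrapper added to the @timeout decorator here recurs in the tool components further down. The reason, shown as a small standalone sketch: os.environ.get returns a str whenever the variable is set, so only the hard-coded fallback 10*60 is numeric without the conversion.

import os

os.environ["COMPONENT_EXEC_TIMEOUT"] = "600"   # values coming from the environment are strings
raw = os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)
print(type(raw).__name__)        # str
print(type(int(raw)).__name__)   # int, which is what a seconds-based timeout expects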
@@ -26,8 +26,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
-from rag.prompts import message_fit_in, citation_prompt
-from rag.prompts.prompts import tool_call_summary
+from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt


 class LLMParam(ComponentParamBase):
@@ -83,8 +82,8 @@ class LLMParam(ComponentParamBase):
 class LLM(ComponentBase):
     component_name = "LLM"

-    def __init__(self, canvas, id, param: ComponentParamBase):
-        super().__init__(canvas, id, param)
+    def __init__(self, canvas, component_id, param: ComponentParamBase):
+        super().__init__(canvas, component_id, param)
         self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
                                   self._param.llm_id, max_retries=self._param.max_retries,
                                   retry_interval=self._param.delay_after_error
@@ -210,7 +209,7 @@ class LLM(ComponentBase):
         return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)

         prompt, msg, _ = self._prepare_prompt_variables()
-        error = ""
+        error: str = ""

         if self._param.output_structure:
             prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
@@ -49,7 +49,7 @@ class MessageParam(ComponentParamBase):
 class Message(ComponentBase):
     component_name = "Message"

-    def get_kwargs(self, script:str, kwargs:dict = {}, delimeter:str=None) -> tuple[str, dict[str, str | list | Any]]:
+    def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
         for k,v in self.get_input_elements_from_text(script).items():
             if k in kwargs:
                 continue
@@ -60,8 +60,8 @@ class Message(ComponentBase):
             if isinstance(v, partial):
                 for t in v():
                     ans += t
-            elif isinstance(v, list) and delimeter:
-                ans = delimeter.join([str(vv) for vv in v])
+            elif isinstance(v, list) and delimiter:
+                ans = delimiter.join([str(vv) for vv in v])
             elif not isinstance(v, str):
                 try:
                     ans = json.dumps(v, ensure_ascii=False)
@@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
             "type": "line"
         } for k, o in self.get_input_elements_from_text(self._param.script).items()}

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         if self._param.method == "split":
             self._split(kwargs.get("line"))
@@ -90,7 +90,7 @@ class StringTransform(Message, ABC):
         for k,v in kwargs.items():
             if not v:
                 v = ""
-            script = re.sub(k, v, script)
+            script = re.sub(k, lambda match: v, script)

         self.set_output("result", script)

@@ -83,7 +83,7 @@
           },
           "password": "20010812Yy!",
           "port": 3306,
-          "sql": "Agent:WickedGoatsDivide@content",
+          "sql": "{Agent:WickedGoatsDivide@content}",
           "username": "13637682833@163.com"
         }
       },
@@ -114,9 +114,7 @@
         "params": {
           "cross_languages": [],
           "empty_response": "",
-          "kb_ids": [
-            "ed31364c727211f0bdb2bafe6e7908e6"
-          ],
+          "kb_ids": [],
           "keywords_similarity_weight": 0.7,
           "outputs": {
             "formalized_content": {
@@ -124,7 +122,7 @@
               "value": ""
             }
           },
-          "query": "sys.query",
+          "query": "{sys.query}",
           "rerank_id": "",
           "similarity_threshold": 0.2,
           "top_k": 1024,
@@ -145,9 +143,7 @@
         "params": {
           "cross_languages": [],
           "empty_response": "",
-          "kb_ids": [
-            "0f968106727311f08357bafe6e7908e6"
-          ],
+          "kb_ids": [],
           "keywords_similarity_weight": 0.7,
           "outputs": {
             "formalized_content": {
@@ -155,7 +151,7 @@
               "value": ""
             }
           },
-          "query": "sys.query",
+          "query": "{sys.query}",
           "rerank_id": "",
           "similarity_threshold": 0.2,
           "top_k": 1024,
@@ -176,9 +172,7 @@
         "params": {
           "cross_languages": [],
           "empty_response": "",
-          "kb_ids": [
-            "4ad1f9d0727311f0827dbafe6e7908e6"
-          ],
+          "kb_ids": [],
           "keywords_similarity_weight": 0.7,
           "outputs": {
             "formalized_content": {
@@ -186,7 +180,7 @@
               "value": ""
             }
          },
-          "query": "sys.query",
+          "query": "{sys.query}",
           "rerank_id": "",
           "similarity_threshold": 0.2,
           "top_k": 1024,
@@ -347,9 +341,7 @@
         "form": {
           "cross_languages": [],
           "empty_response": "",
-          "kb_ids": [
-            "ed31364c727211f0bdb2bafe6e7908e6"
-          ],
+          "kb_ids": [],
           "keywords_similarity_weight": 0.7,
           "outputs": {
             "formalized_content": {
@@ -357,7 +349,7 @@
               "value": ""
             }
           },
-          "query": "sys.query",
+          "query": "{sys.query}",
           "rerank_id": "",
           "similarity_threshold": 0.2,
           "top_k": 1024,
@@ -387,9 +379,7 @@
         "form": {
           "cross_languages": [],
           "empty_response": "",
-          "kb_ids": [
-            "0f968106727311f08357bafe6e7908e6"
-          ],
+          "kb_ids": [],
           "keywords_similarity_weight": 0.7,
           "outputs": {
             "formalized_content": {
@@ -397,7 +387,7 @@
               "value": ""
             }
           },
-          "query": "sys.query",
+          "query": "{sys.query}",
           "rerank_id": "",
           "similarity_threshold": 0.2,
           "top_k": 1024,
@@ -427,9 +417,7 @@
         "form": {
           "cross_languages": [],
           "empty_response": "",
-          "kb_ids": [
-            "4ad1f9d0727311f0827dbafe6e7908e6"
-          ],
+          "kb_ids": [],
           "keywords_similarity_weight": 0.7,
           "outputs": {
             "formalized_content": {
@@ -437,7 +425,7 @@
               "value": ""
             }
           },
-          "query": "sys.query",
+          "query": "{sys.query}",
           "rerank_id": "",
           "similarity_threshold": 0.2,
           "top_k": 1024,
@@ -539,7 +527,7 @@
         },
         "password": "20010812Yy!",
         "port": 3306,
-        "sql": "Agent:WickedGoatsDivide@content",
+        "sql": "{Agent:WickedGoatsDivide@content}",
         "username": "13637682833@163.com"
       },
       "label": "ExeSQL",
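Across the template fixes above, bare references such as "sys.query" and "Agent:WickedGoatsDivide@content" become "{sys.query}" and "{Agent:WickedGoatsDivide@content}", i.e. brace-wrapped variable references instead of literal strings. A generic illustration of brace-style placeholder resolution follows; it is not RAGFlow's actual resolver, just the idea:

import re

def resolve(template: str, variables: dict) -> str:
    # Replace each {name} with its value; unknown names are left untouched.
    return re.sub(r"\{([^{}]+)\}", lambda m: str(variables.get(m.group(1), m.group(0))), template)

print(resolve("{sys.query}", {"sys.query": "What is RAGFlow?"}))  # -> What is RAGFlow?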
@@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
 class ArXiv(ToolBase, ABC):
     component_name = "ArXiv"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -22,7 +22,7 @@ from typing import TypedDict, List, Any
 from agent.component.base import ComponentParamBase, ComponentBase
 from api.utils import hash_str2int
 from rag.llm.chat_model import ToolCallSession
-from rag.prompts.prompts import kb_prompt
+from rag.prompts.generator import kb_prompt
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession
 from timeit import default_timer as timer

@@ -129,7 +129,7 @@ module.exports = { main };
 class CodeExec(ToolBase, ABC):
     component_name = "CodeExec"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         lang = kwargs.get("lang", self._param.lang)
         script = kwargs.get("script", self._param.script)
@@ -157,7 +157,7 @@ class CodeExec(ToolBase, ABC):

         try:
             resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
-            logging.info(f"http://{settings.SANDBOX_HOST}:9385/run", code_req, resp.status_code)
+            logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
             if resp.status_code != 200:
                 resp.raise_for_status()
             body = resp.json()
@@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
 class DuckDuckGo(ToolBase, ABC):
     component_name = "DuckDuckGo"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -99,7 +99,7 @@ class EmailParam(ToolParamBase):
 class Email(ToolBase, ABC):
     component_name = "Email"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if not kwargs.get("to_email"):
             self.set_output("success", False)
@@ -53,7 +53,7 @@ class ExeSQLParam(ToolParamBase):
         self.max_records = 1024

     def check(self):
-        self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql'])
+        self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql'])
         self.check_empty(self.database, "Database name")
         self.check_empty(self.username, "database username")
         self.check_empty(self.host, "IP Address")
@@ -78,7 +78,7 @@ class ExeSQLParam(ToolParamBase):
 class ExeSQL(ToolBase, ABC):
     component_name = "ExeSQL"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):

         def convert_decimals(obj):
@@ -111,7 +111,7 @@ class ExeSQL(ToolBase, ABC):
         if self._param.db_type in ["mysql", "mariadb"]:
             db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
                                  port=self._param.port, password=self._param.password)
-        elif self._param.db_type == 'postgresql':
+        elif self._param.db_type == 'postgres':
             db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
                                   port=self._param.port, password=self._param.password)
         elif self._param.db_type == 'mssql':
@@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
 class GitHub(ToolBase, ABC):
     component_name = "GitHub"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
 class Google(ToolBase, ABC):
     component_name = "Google"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("q"):
             self.set_output("formalized_content", "")
@@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
 class GoogleScholar(ToolBase, ABC):
     component_name = "GoogleScholar"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
 class PubMed(ToolBase, ABC):
     component_name = "PubMed"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -23,8 +23,7 @@ from api.db.services.llm_service import LLMBundle
 from api import settings
 from api.utils.api_utils import timeout
 from rag.app.tag import label_question
-from rag.prompts import kb_prompt
-from rag.prompts.prompts import cross_languages
+from rag.prompts.generator import cross_languages, kb_prompt


 class RetrievalParam(ToolParamBase):
@@ -75,7 +74,7 @@ class RetrievalParam(ToolParamBase):
 class Retrieval(ToolBase, ABC):
     component_name = "Retrieval"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", self._param.empty_response)
@@ -163,9 +162,16 @@ class Retrieval(ToolBase, ABC):
             self.set_output("formalized_content", self._param.empty_response)
             return

+        # Format the chunks for JSON output (similar to how other tools do it)
+        json_output = kbinfos["chunks"].copy()
+
         self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
         form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
+
+        # Set both formalized content and JSON output
         self.set_output("formalized_content", form_cnt)
+        self.set_output("json", json_output)
+
         return form_cnt

     def thoughts(self) -> str:
@@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
 class SearXNG(ToolBase, ABC):
     component_name = "SearXNG"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         # Gracefully handle try-run without inputs
         query = kwargs.get("query")
@@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
         last_e = ""
         for _ in range(self._param.max_retries+1):
             try:
-                # 构建搜索参数
                 search_params = {
                     'q': query,
                     'format': 'json',
@@ -104,7 +103,6 @@ class SearXNG(ToolBase, ABC):
                     'pageno': 1
                 }

-                # 发送搜索请求
                 response = requests.get(
                     f"{searxng_url}/search",
                     params=search_params,
@@ -114,7 +112,6 @@ class SearXNG(ToolBase, ABC):

                 data = response.json()

-                # 验证响应数据
                 if not data or not isinstance(data, dict):
                     raise ValueError("Invalid response from SearXNG")

@@ -122,10 +119,8 @@ class SearXNG(ToolBase, ABC):
                 if not isinstance(results, list):
                     raise ValueError("Invalid results format from SearXNG")

-                # 限制结果数量
                 results = results[:self._param.top_n]

-                # 处理搜索结果
                 self._retrieve_chunks(results,
                                       get_title=lambda r: r.get("title", ""),
                                       get_url=lambda r: r.get("url", ""),
@@ -101,7 +101,7 @@ When searching:
 class TavilySearch(ToolBase, ABC):
     component_name = "TavilySearch"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
 class TavilyExtract(ToolBase, ABC):
     component_name = "TavilyExtract"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         self.tavily_client = TavilyClient(api_key=self._param.api_key)
         last_e = None
@@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
 class WenCai(ToolBase, ABC):
     component_name = "WenCai"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("report", "")
@@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
 class Wikipedia(ToolBase, ABC):
     component_name = "Wikipedia"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
 class YahooFinance(ToolBase, ABC):
     component_name = "YahooFinance"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if not kwargs.get("stock_code"):
             self.set_output("report", "")
@@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from api.db import StatusEnum
 from api.db.db_models import close_connection
 from api.db.services import UserService
-from api.utils import CustomJSONEncoder, commands
+from api.utils.json import CustomJSONEncoder
+from api.utils import commands

 from flask_mail import Mail
 from flask_session import Session
@@ -39,7 +39,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge

 from api.utils.file_utils import filename_type, thumbnail
 from rag.app.tag import label_question
-from rag.prompts import keyword_extraction
+from rag.prompts.generator import keyword_extraction
 from rag.utils.storage_factory import STORAGE_IMPL

 from api.db.services.canvas_service import UserCanvasService
@@ -23,7 +23,7 @@ import trio
 from flask import request, Response
 from flask_login import login_required, current_user

-from agent.component import LLM
+from agent.component.llm import LLM
 from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
@@ -332,7 +332,7 @@ def test_db_connect():
     if req["db_type"] in ["mysql", "mariadb"]:
         db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
                            password=req["password"])
-    elif req["db_type"] == 'postgresql':
+    elif req["db_type"] == 'postgres':
         db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
                                 password=req["password"])
     elif req["db_type"] == 'mssql':
@@ -474,7 +474,7 @@ def sessions(canvas_id):
 @manager.route('/prompts', methods=['GET']) # noqa: F821
 @login_required
 def prompts():
-    from rag.prompts.prompts import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
+    from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
     return get_json_result(data={
         "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
         "plan_generation": NEXT_STEP,
@@ -33,8 +33,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
-from rag.prompts import cross_languages, keyword_extraction
-from rag.prompts.prompts import gen_meta_filter
+from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
 from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace

@@ -15,7 +15,7 @@
 #
 import json
 import re
-import traceback
+import logging
 from copy import deepcopy
 from flask import Response, request
 from flask_login import current_user, login_required
@@ -29,8 +29,8 @@ from api.db.services.search_service import SearchService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
-from rag.prompts.prompt_template import load_prompt
-from rag.prompts.prompts import chunks_format
+from rag.prompts.template import load_prompt
+from rag.prompts.generator import chunks_format


 @manager.route("/set", methods=["POST"]) # noqa: F821
@@ -226,7 +226,7 @@ def completion():
             if not is_embedded:
                 ConversationService.update_by_id(conv.id, conv.to_dict())
         except Exception as e:
-            traceback.print_exc()
+            logging.exception(e)
             yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
         yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

@@ -24,7 +24,7 @@ from flask import request
 from flask_login import current_user, login_required

 from agent.canvas import Canvas
-from agent.component import LLM
+from agent.component.llm import LLM
 from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
 from api.db.services.document_service import DocumentService
@@ -23,7 +23,7 @@ from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
+from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters, active_required
 from api.utils import get_uuid
 from api.db import StatusEnum, FileSource
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -38,6 +38,7 @@ from rag.utils.storage_factory import STORAGE_IMPL

 @manager.route('/create', methods=['post']) # noqa: F821
 @login_required
+@active_required
 @validate_request("name")
 def create():
     req = request.json
@@ -379,3 +380,19 @@ def get_meta():
             code=settings.RetCode.AUTHENTICATION_ERROR
         )
     return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
+
+
+@manager.route("/basic_info", methods=["GET"]) # noqa: F821
+@login_required
+def get_basic_info():
+    kb_id = request.args.get("kb_id", "")
+    if not KnowledgebaseService.accessible(kb_id, current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+
+    basic_info = DocumentService.knowledgebase_basic_info(kb_id)
+
+    return get_json_result(data=basic_info)
@@ -40,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
-from rag.prompts import cross_languages, keyword_extraction
+from rag.prompts.generator import cross_languages, keyword_extraction
 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL

@@ -3,9 +3,11 @@ import re

 import flask
 from flask import request
+from pathlib import Path

 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
+from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import server_error_response, token_required
 from api.utils import get_uuid
 from api.db import FileType
@@ -81,16 +83,16 @@ def upload(tenant_id):
         return get_json_result(data=False, message="Can't find this folder!", code=404)

     for file_obj in file_objs:
-        # 文件路径处理
+        # Handle file path
         full_path = '/' + file_obj.filename
         file_obj_names = full_path.split('/')
         file_len = len(file_obj_names)

-        # 获取文件夹路径ID
+        # Get folder path ID
         file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
         len_id_list = len(file_id_list)

-        # 创建文件夹结构
+        # Crete file folder
         if file_len != len_id_list:
             e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
             if not e:
@@ -666,3 +668,71 @@ def move(tenant_id):
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
+
+
+@manager.route('/file/convert', methods=['POST']) # noqa: F821
+@token_required
+def convert(tenant_id):
+    req = request.json
+    kb_ids = req["kb_ids"]
+    file_ids = req["file_ids"]
+    file2documents = []
+
+    try:
+        files = FileService.get_by_ids(file_ids)
+        files_set = dict({file.id: file for file in files})
+        for file_id in file_ids:
+            file = files_set[file_id]
+            if not file:
+                return get_json_result(message="File not found!", code=404)
+            file_ids_list = [file_id]
+            if file.type == FileType.FOLDER.value:
+                file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
+            for id in file_ids_list:
+                informs = File2DocumentService.get_by_file_id(id)
+                # delete
+                for inform in informs:
+                    doc_id = inform.document_id
+                    e, doc = DocumentService.get_by_id(doc_id)
+                    if not e:
+                        return get_json_result(message="Document not found!", code=404)
+                    tenant_id = DocumentService.get_tenant_id(doc_id)
+                    if not tenant_id:
+                        return get_json_result(message="Tenant not found!", code=404)
+                    if not DocumentService.remove_document(doc, tenant_id):
+                        return get_json_result(
+                            message="Database error (Document removal)!", code=404)
+                    File2DocumentService.delete_by_file_id(id)
+
+                # insert
+                for kb_id in kb_ids:
+                    e, kb = KnowledgebaseService.get_by_id(kb_id)
+                    if not e:
+                        return get_json_result(
+                            message="Can't find this knowledgebase!", code=404)
+                    e, file = FileService.get_by_id(id)
+                    if not e:
+                        return get_json_result(
+                            message="Can't find this file!", code=404)
+
+                    doc = DocumentService.insert({
+                        "id": get_uuid(),
+                        "kb_id": kb.id,
+                        "parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
+                        "parser_config": kb.parser_config,
+                        "created_by": tenant_id,
+                        "type": file.type,
+                        "name": file.name,
+                        "suffix": Path(file.name).suffix.lstrip("."),
+                        "location": file.location,
+                        "size": file.size
+                    })
+                    file2document = File2DocumentService.insert({
+                        "id": get_uuid(),
+                        "file_id": id,
+                        "document_id": doc.id,
+                    })
+
+                    file2documents.append(file2document.to_json())
+        return get_json_result(data=file2documents)
+    except Exception as e:
+        return server_error_response(e)
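For orientation, a call to the new token-protected /file/convert route might look like the sketch below. The base URL, the /api/v1 prefix and the Bearer header are assumptions modelled on the other SDK routes, and every ID is a placeholder.

import requests

# Hypothetical values; substitute a real API key, knowledge base IDs and file IDs.
resp = requests.post(
    "http://localhost:9380/api/v1/file/convert",
    headers={"Authorization": "Bearer <RAGFLOW_API_KEY>"},
    json={"kb_ids": ["<kb_id>"], "file_ids": ["<file_id>"]},
    timeout=30,
)
print(resp.json())  # on success: one entry per created file-to-document binding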
@@ -38,9 +38,8 @@ from api.db.services.user_service import UserTenantService
 from api.utils import get_uuid
 from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
 from rag.app.tag import label_question
-from rag.prompts import chunks_format
-from rag.prompts.prompt_template import load_prompt
-from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
+from rag.prompts.template import load_prompt
+from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format


 @manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@@ -414,7 +413,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
                 tenant_id,
                 agent_id,
                 question,
-                session_id=req.get("session_id", req.get("id", "") or req.get("metadata", {}).get("id", "")),
+                session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
                 stream=True,
                 **req,
             ),
@@ -432,7 +431,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
             tenant_id,
             agent_id,
             question,
-            session_id=req.get("session_id", req.get("id", "") or req.get("metadata", {}).get("id", "")),
+            session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
             stream=False,
             **req,
         )
@@ -36,6 +36,8 @@ from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
 from timeit import default_timer as timer

 from rag.utils.redis_conn import REDIS_CONN
+from flask import jsonify
+from api.utils.health_utils import run_health_checks

 @manager.route("/version", methods=["GET"]) # noqa: F821
 @login_required
@@ -169,6 +171,12 @@ def status():
     return get_json_result(data=res)


+@manager.route("/healthz", methods=["GET"]) # noqa: F821
+def healthz():
+    result, all_ok = run_health_checks()
+    return jsonify(result), (200 if all_ok else 500)
+
+
 @manager.route("/new_token", methods=["POST"]) # noqa: F821
 @login_required
 def new_token():
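A quick probe of the new unauthenticated health endpoint, as a sketch only; the host, port and the /v1/system prefix depend on the deployment and are assumptions here:

import requests

# run_health_checks() drives the response: HTTP 200 when every dependency passes, 500 otherwise.
resp = requests.get("http://localhost:9380/v1/system/healthz", timeout=5)
print(resp.status_code, resp.json())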
@@ -34,7 +34,6 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
 from api.utils import (
     current_timestamp,
     datetime_format,
-    decrypt,
     download_img,
     get_format_time,
     get_uuid,
@@ -46,6 +45,7 @@ from api.utils.api_utils import (
     server_error_response,
     validate_request,
 )
+from api.utils.crypt import decrypt


 @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@@ -23,6 +23,11 @@ class StatusEnum(Enum):
     INVALID = "0"


+class ActiveEnum(Enum):
+    ACTIVE = "1"
+    INACTIVE = "0"
+
+
 class UserTenantRole(StrEnum):
     OWNER = 'owner'
     ADMIN = 'admin'
@@ -111,7 +116,7 @@ class CanvasCategory(StrEnum):
     Agent = "agent_canvas"
     DataFlow = "dataflow_canvas"

-VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
+VALID_CANVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}


 class MCPServerType(StrEnum):
@@ -26,12 +26,14 @@ from functools import wraps

 from flask_login import UserMixin
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
-from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
+from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
 from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
 from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase

 from api import settings, utils
 from api.db import ParserType, SerializedType
+from api.utils.json import json_dumps, json_loads
+from api.utils.configs import deserialize_b64, serialize_b64


 def singleton(cls, *args, **kw):
@@ -70,12 +72,12 @@ class JSONField(LongTextField):
     def db_value(self, value):
         if value is None:
             value = self.default_value
-        return utils.json_dumps(value)
+        return json_dumps(value)

     def python_value(self, value):
         if not value:
             return self.default_value
-        return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
+        return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)


 class ListField(JSONField):
@@ -91,21 +93,21 @@ class SerializedField(LongTextField):

     def db_value(self, value):
         if self._serialized_type == SerializedType.PICKLE:
-            return utils.serialize_b64(value, to_str=True)
+            return serialize_b64(value, to_str=True)
         elif self._serialized_type == SerializedType.JSON:
             if value is None:
                 return None
-            return utils.json_dumps(value, with_type=True)
+            return json_dumps(value, with_type=True)
         else:
             raise ValueError(f"the serialized type {self._serialized_type} is not supported")

     def python_value(self, value):
         if self._serialized_type == SerializedType.PICKLE:
-            return utils.deserialize_b64(value)
+            return deserialize_b64(value)
         elif self._serialized_type == SerializedType.JSON:
             if value is None:
                 return {}
-            return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
+            return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
         else:
             raise ValueError(f"the serialized type {self._serialized_type} is not supported")

@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
         super().__init__(*args, **kwargs)
 
     def execute_sql(self, sql, params=None, commit=True):
-        from peewee import OperationalError
-
         for attempt in range(self.max_retries + 1):
             try:
                 return super().execute_sql(sql, params, commit)
-            except OperationalError as e:
-                if e.args[0] in (2013, 2006) and attempt < self.max_retries:
-                    logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
+            except (OperationalError, InterfaceError) as e:
+                error_codes = [2013, 2006]
+                error_messages = ['', 'Lost connection']
+                should_retry = (
+                    (hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
+                    (str(e) in error_messages) or
+                    (hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
+                )
+
+                if should_retry and attempt < self.max_retries:
+                    logging.warning(
+                        f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
+                    )
                     self._handle_connection_loss()
-                    time.sleep(self.retry_delay * (2**attempt))
+                    time.sleep(self.retry_delay * (2 ** attempt))
                 else:
                     logging.error(f"DB execution failure: {e}")
                     raise
         return None
 
     def _handle_connection_loss(self):
-        self.close_all()
-        self.connect()
+        # self.close_all()
+        # self.connect()
+        try:
+            self.close()
+        except Exception:
+            pass
+        try:
+            self.connect()
+        except Exception as e:
+            logging.error(f"Failed to reconnect: {e}")
+            time.sleep(0.1)
+            self.connect()
 
     def begin(self):
-        from peewee import OperationalError
-
         for attempt in range(self.max_retries + 1):
             try:
                 return super().begin()
-            except OperationalError as e:
-                if e.args[0] in (2013, 2006) and attempt < self.max_retries:
-                    logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
+            except (OperationalError, InterfaceError) as e:
+                error_codes = [2013, 2006]
+                error_messages = ['', 'Lost connection']
+
+                should_retry = (
+                    (hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
+                    (str(e) in error_messages) or
+                    (hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
+                )
+
+                if should_retry and attempt < self.max_retries:
+                    logging.warning(
+                        f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
+                    )
                     self._handle_connection_loss()
-                    time.sleep(self.retry_delay * (2**attempt))
+                    time.sleep(self.retry_delay * (2 ** attempt))
                 else:
                     raise
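
A note on the back-off used above: each retry sleeps `retry_delay * (2 ** attempt)` seconds. A minimal sketch of the resulting wait schedule, assuming the pool is created with `max_retries=5` and `retry_delay=1` (the values added in the `BaseDataBase` hunk below):

```python
# Sketch only: the wait schedule produced by retry_delay * (2 ** attempt),
# assuming max_retries=5 and retry_delay=1 as configured in BaseDataBase below.
max_retries, retry_delay = 5, 1
waits = [retry_delay * (2 ** attempt) for attempt in range(max_retries)]
print(waits)  # [1, 2, 4, 8, 16] -> roughly 31 seconds of total back-off
```
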
@ -299,7 +328,16 @@ class BaseDataBase:
     def __init__(self):
         database_config = settings.DATABASE.copy()
         db_name = database_config.pop("name")
-        self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
+
+        pool_config = {
+            'max_retries': 5,
+            'retry_delay': 1,
+        }
+        database_config.update(pool_config)
+        self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
+            db_name, **database_config
+        )
+        # self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
         logging.info("init database on cluster mode successfully")
 
 
@ -144,8 +144,9 @@ def init_llm_factory():
         except Exception:
             pass
         break
+    doc_count = DocumentService.get_all_kb_doc_count()
     for kb_id in KnowledgebaseService.get_all_ids():
-        KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=DocumentService.get_kb_doc_count(kb_id))
+        KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=doc_count.get(kb_id, 0))
 
 
api/db/joint_services/__init__.py (new file, 0 lines)

api/db/joint_services/user_account_service.py (new file, 120 lines)
@ -0,0 +1,120 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+import uuid
+
+from api import settings
+from api.db import FileType, UserTenantRole
+from api.db.db_models import TenantLLM
+from api.db.services.llm_service import get_init_tenant_llm
+from api.db.services.file_service import FileService
+from api.db.services.tenant_llm_service import TenantLLMService
+from api.db.services.user_service import TenantService, UserService, UserTenantService
+
+
+def create_new_user(user_info: dict) -> dict:
+    """
+    Add a new user, and create tenant, tenant llm, file folder for new user.
+    :param user_info: {
+        "email": <example@example.com>,
+        "nickname": <str, "name">,
+        "password": <decrypted password>,
+        "login_channel": <enum, "password">,
+        "is_superuser": <bool, role == "admin">,
+    }
+    :return: {
+        "success": <bool>,
+        "user_info": <dict>,  # if true, return user_info
+    }
+    """
+    # generate user_id and access_token for user
+    user_id = uuid.uuid1().hex
+    user_info['id'] = user_id
+    user_info['access_token'] = uuid.uuid1().hex
+    # construct tenant info
+    tenant = {
+        "id": user_id,
+        "name": user_info["nickname"] + "‘s Kingdom",
+        "llm_id": settings.CHAT_MDL,
+        "embd_id": settings.EMBEDDING_MDL,
+        "asr_id": settings.ASR_MDL,
+        "parser_ids": settings.PARSERS,
+        "img2txt_id": settings.IMAGE2TEXT_MDL,
+        "rerank_id": settings.RERANK_MDL,
+    }
+    usr_tenant = {
+        "tenant_id": user_id,
+        "user_id": user_id,
+        "invited_by": user_id,
+        "role": UserTenantRole.OWNER,
+    }
+    # construct file folder info
+    file_id = uuid.uuid1().hex
+    file = {
+        "id": file_id,
+        "parent_id": file_id,
+        "tenant_id": user_id,
+        "created_by": user_id,
+        "name": "/",
+        "type": FileType.FOLDER.value,
+        "size": 0,
+        "location": "",
+    }
+    try:
+        tenant_llm = get_init_tenant_llm(user_id)
+
+        if not UserService.save(**user_info):
+            return {"success": False}
+
+        TenantService.insert(**tenant)
+        UserTenantService.insert(**usr_tenant)
+        TenantLLMService.insert_many(tenant_llm)
+        FileService.insert(file)
+
+        return {
+            "success": True,
+            "user_info": user_info,
+        }
+    except Exception as create_error:
+        logging.exception(create_error)
+        # rollback
+        try:
+            TenantService.delete_by_id(user_id)
+        except Exception as e:
+            logging.exception(e)
+        try:
+            u = UserTenantService.query(tenant_id=user_id)
+            if u:
+                UserTenantService.delete_by_id(u[0].id)
+        except Exception as e:
+            logging.exception(e)
+        try:
+            TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
+        except Exception as e:
+            logging.exception(e)
+        try:
+            FileService.delete_by_id(file["id"])
+        except Exception as e:
+            logging.exception(e)
+        # delete user row finally
+        try:
+            UserService.delete_by_id(user_id)
+        except Exception as e:
+            logging.exception(e)
+        # reraise
+        raise create_error
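
A hedged usage sketch of `create_new_user` as added above. The field values are made up, and the call assumes an initialized RAGFlow database and `api.settings`; it is not a standalone script:

```python
# Illustration only: calling create_new_user with hypothetical values.
# Requires an initialized RAGFlow database and settings; not part of the patch itself.
from api.db.joint_services.user_account_service import create_new_user

result = create_new_user({
    "email": "example@example.com",          # hypothetical
    "nickname": "demo",
    "password": "plain-text-after-decrypt",  # caller decrypts before calling
    "login_channel": "password",
    "is_superuser": False,
})
if result["success"]:
    print("created user", result["user_info"]["id"])
# On failure the function rolls back tenant, tenant LLM, folder and user rows, then re-raises.
```
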
@ -19,7 +19,7 @@ from pathlib import PurePath
 from .user_service import UserService as UserService
 
 
-def split_name_counter(filename: str) -> tuple[str, int | None]:
+def _split_name_counter(filename: str) -> tuple[str, int | None]:
     """
     Splits a filename into main part and counter (if present in parentheses).
 
@ -87,7 +87,7 @@ def duplicate_name(query_func, **kwargs) -> str:
     stem = path.stem
     suffix = path.suffix
 
-    main_part, counter = split_name_counter(stem)
+    main_part, counter = _split_name_counter(stem)
     counter = counter + 1 if counter else 1
 
     new_name = f"{main_part}({counter}){suffix}"
@ -61,6 +61,36 @@ class UserCanvasService(CommonService):
 
         return list(agents.dicts())
 
+    @classmethod
+    @DB.connection_context()
+    def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
+        # will get all permitted agents, be cautious
+        fields = [
+            cls.model.title,
+            cls.model.permission,
+            cls.model.canvas_type,
+            cls.model.canvas_category
+        ]
+        # find team agents and owned agents
+        agents = cls.model.select(*fields).where(
+            (cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
+                cls.model.user_id == user_id
+            )
+        )
+        # sort by create_time, asc
+        agents.order_by(cls.model.create_time.asc())
+        # maybe cause slow query by deep paginate, optimize later
+        offset, limit = 0, 50
+        res = []
+        while True:
+            ag_batch = agents.offset(offset).limit(limit)
+            _temp = list(ag_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res
+
     @classmethod
     @DB.connection_context()
     def get_by_tenant_id(cls, pid):
@ -14,12 +14,24 @@
 #  limitations under the License.
 #
 from datetime import datetime
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 import peewee
+from peewee import InterfaceError, OperationalError
 
 from api.db.db_models import DB
 from api.utils import current_timestamp, datetime_format, get_uuid
 
 
+def retry_db_operation(func):
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=1, max=5),
+        retry=retry_if_exception_type((InterfaceError, OperationalError)),
+        before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
+        reraise=True,
+    )
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+    return wrapper
+
 
 class CommonService:
     """Base service class that provides common database operations.
@ -202,6 +214,7 @@ class CommonService:
 
     @classmethod
     @DB.connection_context()
+    @retry_db_operation
     def update_by_id(cls, pid, data):
         # Update a single record by ID
         # Args:
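
For context, a minimal sketch of the tenacity retry pattern that `retry_db_operation` wraps around `update_by_id`. The `flaky_query` function and its failure behaviour are hypothetical, not part of the patch:

```python
# Minimal sketch of the tenacity-based retry used by retry_db_operation.
# flaky_query and its failure counts are hypothetical, for illustration only.
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type


class FlakyConnectionError(Exception):
    pass


attempts = {"n": 0}


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=5),
    retry=retry_if_exception_type(FlakyConnectionError),
    reraise=True,
)
def flaky_query():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise FlakyConnectionError("lost connection")
    return "ok"


print(flaky_query(), "after", attempts["n"], "attempts")  # ok after 3 attempts
```
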
@ -23,7 +23,7 @@ from api.db.services.dialog_service import DialogService, chat
 from api.utils import get_uuid
 import json
 
-from rag.prompts import chunks_format
+from rag.prompts.generator import chunks_format
 
 
 class ConversationService(CommonService):
@ -39,8 +39,8 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
-from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
+from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
+    gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
 
@ -176,7 +176,7 @@ def chat_solo(dialog, messages, stream=True):
         delta_ans = ""
         for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
             answer = ans
-            delta_ans = ans[len(last_ans) :]
+            delta_ans = ans[len(last_ans):]
             if num_tokens_from_string(delta_ans) < 16:
                 continue
             last_ans = answer
@ -261,13 +261,13 @@ def convert_conditions(metadata_condition):
         "not is": "≠"
     }
     return [
         {
             "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
             "key": cond["name"],
             "value": cond["value"]
         }
         for cond in metadata_condition.get("conditions", [])
     ]
 
 
 def meta_filter(metas: dict, filters: list[dict]):
@ -284,19 +284,19 @@ def meta_filter(metas: dict, filters: list[dict]):
             value = str(value)
 
         for conds in [
             (operator == "contains", str(value).lower() in str(input).lower()),
             (operator == "not contains", str(value).lower() not in str(input).lower()),
             (operator == "start with", str(input).lower().startswith(str(value).lower())),
             (operator == "end with", str(input).lower().endswith(str(value).lower())),
             (operator == "empty", not input),
             (operator == "not empty", input),
             (operator == "=", input == value),
             (operator == "≠", input != value),
             (operator == ">", input > value),
             (operator == "<", input < value),
             (operator == "≥", input >= value),
             (operator == "≤", input <= value),
         ]:
             try:
                 if all(conds):
                     ids.extend(docids)
@ -456,7 +456,8 @@ def chat(dialog, messages, stream=True, **kwargs):
             kbinfos["chunks"].extend(tav_res["chunks"])
             kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
         if prompt_config.get("use_kg"):
-            ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
+            ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
+                                                   LLMBundle(dialog.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 kbinfos["chunks"].insert(0, ck)
 
@ -467,7 +468,8 @@ def chat(dialog, messages, stream=True, **kwargs):
     retrieval_ts = timer()
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
-        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
+        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
+               "audio_binary": tts(tts_mdl, empty_res)}
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@ -565,7 +567,8 @@ def chat(dialog, messages, stream=True, **kwargs):
 
     if langfuse_tracer:
         langfuse_generation = langfuse_tracer.start_generation(
-            trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
+            trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
+            input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
         )
 
     if stream:
@ -575,12 +578,12 @@ def chat(dialog, messages, stream=True, **kwargs):
             if thought:
                 ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
             answer = ans
-            delta_ans = ans[len(last_ans) :]
+            delta_ans = ans[len(last_ans):]
             if num_tokens_from_string(delta_ans) < 16:
                 continue
             last_ans = answer
             yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        delta_ans = answer[len(last_ans) :]
+        delta_ans = answer[len(last_ans):]
         if delta_ans:
             yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
         yield decorate_answer(thought + answer)
@ -676,7 +679,9 @@ Please write the SQL, only SQL, without any other explanations or text.
 
     # compose Markdown table
     columns = (
-        "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+        "|" + "|".join(
+            [re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
+            "|Source|" if docid_idx and docid_idx else "|")
     )
 
     line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@ -753,7 +758,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         doc_ids = None
 
     kbinfos = retriever.retrieval(
-        question = question,
+        question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,
         kb_ids=kb_ids,
@ -775,7 +780,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
 
     def decorate_answer(answer):
         nonlocal knowledges, kbinfos, sys_prompt
-        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
+        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
+                                                 embd_mdl, tkweight=0.7, vtweight=0.3)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
         recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
         if not recall_docs:
@ -24,7 +24,7 @@ from io import BytesIO
 
 import trio
 import xxhash
-from peewee import fn
+from peewee import fn, Case
 
 from api import settings
 from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
@ -660,8 +660,16 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_kb_doc_count(cls, kb_id):
-        return len(cls.model.select(cls.model.id).where(
-            cls.model.kb_id == kb_id).dicts())
+        return cls.model.select().where(cls.model.kb_id == kb_id).count()
+
+    @classmethod
+    @DB.connection_context()
+    def get_all_kb_doc_count(cls):
+        result = {}
+        rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
+        for row in rows:
+            result[row.kb_id] = row.count
+        return result
 
     @classmethod
     @DB.connection_context()
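
The point of `get_all_kb_doc_count` is to replace one COUNT query per knowledge base with a single GROUP BY, which is what the `init_llm_factory` hunk earlier switches to. A rough sketch of the calling pattern, assuming an initialized RAGFlow database (the printed ids are whatever exists locally):

```python
# Sketch: one grouped query instead of one count query per knowledge base.
# Assumes an initialized RAGFlow database; not a standalone script.
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService

doc_count = DocumentService.get_all_kb_doc_count()      # {kb_id: count, ...} from a single query
for kb_id in KnowledgebaseService.get_all_ids():
    print(kb_id, doc_count.get(kb_id, 0))                # knowledge bases with no documents default to 0
```
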
@ -674,6 +682,53 @@ class DocumentService(CommonService):
         return False
 
 
+    @classmethod
+    @DB.connection_context()
+    def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
+        # cancelled: run == "2" but progress can vary
+        cancelled = (
+            cls.model.select(fn.COUNT(1))
+            .where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
+            .scalar()
+        )
+
+        row = (
+            cls.model.select(
+                # finished: progress == 1
+                fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
+                # failed: progress == -1
+                fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
+                # processing: 0 <= progress < 1
+                fn.COALESCE(
+                    fn.SUM(
+                        Case(
+                            None,
+                            [
+                                (((cls.model.progress == 0) | ((cls.model.progress > 0) & (cls.model.progress < 1))), 1),
+                            ],
+                            0,
+                        )
+                    ),
+                    0,
+                ).alias("processing"),
+            )
+            .where(
+                (cls.model.kb_id == kb_id)
+                & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
+            )
+            .dicts()
+            .get()
+        )
+
+        return {
+            "processing": int(row["processing"]),
+            "finished": int(row["finished"]),
+            "failed": int(row["failed"]),
+            "cancelled": int(cancelled),
+        }
+
+
 def queue_raptor_o_graphrag_tasks(doc, ty, priority):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
@ -702,6 +757,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
 
 def get_queue_length(priority):
     group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
+    if not group_info:
+        return 0
     return int(group_info.get("lag", 0) or 0)
 
 
@ -847,3 +904,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
 
     return [d["id"] for d, _ in files]
+
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
 
         return list(kbs.dicts()), count
 
+    @classmethod
+    @DB.connection_context()
+    def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
+        # will get all permitted kb, be cautious.
+        fields = [
+            cls.model.name,
+            cls.model.language,
+            cls.model.permission,
+            cls.model.doc_num,
+            cls.model.token_num,
+            cls.model.chunk_num,
+            cls.model.status,
+            cls.model.create_date,
+            cls.model.update_date
+        ]
+        # find team kb and owned kb
+        kbs = cls.model.select(*fields).where(
+            (cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
+                cls.model.tenant_id == user_id
+            )
+        )
+        # sort by create_time asc
+        kbs.order_by(cls.model.create_time.asc())
+        # maybe cause slow query by deep paginate, optimize later.
+        offset, limit = 0, 50
+        res = []
+        while True:
+            kb_batch = kbs.offset(offset).limit(limit)
+            _temp = list(kb_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res
+
     @classmethod
     @DB.connection_context()
     def get_kb_ids(cls, tenant_id):
@ -100,6 +100,12 @@ class UserService(CommonService):
         else:
             return None
 
+    @classmethod
+    @DB.connection_context()
+    def query_user_by_email(cls, email):
+        users = cls.model.select().where((cls.model.email == email))
+        return list(users)
+
     @classmethod
     @DB.connection_context()
     def save(cls, **kwargs):
@ -133,6 +139,17 @@ class UserService(CommonService):
         cls.model.update(user_dict).where(
             cls.model.id == user_id).execute()
 
+    @classmethod
+    @DB.connection_context()
+    def update_user_password(cls, user_id, new_password):
+        with DB.atomic():
+            update_dict = {
+                "password": generate_password_hash(str(new_password)),
+                "update_time": current_timestamp(),
+                "update_date": datetime_format(datetime.now())
+            }
+            cls.model.update(update_dict).where(cls.model.id == user_id).execute()
+
     @classmethod
     @DB.connection_context()
     def is_admin(cls, user_id):
@ -140,6 +157,12 @@ class UserService(CommonService):
             cls.model.id == user_id,
             cls.model.is_superuser == 1).count() > 0
 
+    @classmethod
+    @DB.connection_context()
+    def get_all_users(cls):
+        users = cls.model.select()
+        return list(users)
+
 
 class TenantService(CommonService):
     """Service class for managing tenant-related database operations.
@ -41,7 +41,7 @@ from api import utils
 from api.db.db_models import init_database_tables as init_web_db
 from api.db.init_data import init_web_data
 from api.versions import get_ragflow_version
-from api.utils import show_configs
+from api.utils.configs import show_configs
 from rag.settings import print_rag_settings
 from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
 from rag.utils.redis_conn import RedisDistributedLock
@ -24,7 +24,7 @@ import rag.utils.es_conn
 import rag.utils.infinity_conn
 import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
-from api.utils import decrypt_database_config, get_base_config
+from api.utils.configs import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
 from rag.nlp import search
 
@ -16,184 +16,15 @@
 import base64
 import datetime
 import hashlib
-import io
-import json
 import os
-import pickle
 import socket
 import time
 import uuid
 import requests
-import logging
-import copy
-from enum import Enum, IntEnum
 import importlib
-from Cryptodome.PublicKey import RSA
-from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
-from filelock import FileLock
-from api.constants import SERVICE_CONF
 
-from . import file_utils
+from .common import string_to_bytes
 
 
-def conf_realpath(conf_name):
-    conf_path = f"conf/{conf_name}"
-    return os.path.join(file_utils.get_project_base_directory(), conf_path)
-
-
-def read_config(conf_name=SERVICE_CONF):
-    local_config = {}
-    local_path = conf_realpath(f'local.{conf_name}')
-
-    # load local config file
-    if os.path.exists(local_path):
-        local_config = file_utils.load_yaml_conf(local_path)
-        if not isinstance(local_config, dict):
-            raise ValueError(f'Invalid config file: "{local_path}".')
-
-    global_config_path = conf_realpath(conf_name)
-    global_config = file_utils.load_yaml_conf(global_config_path)
-
-    if not isinstance(global_config, dict):
-        raise ValueError(f'Invalid config file: "{global_config_path}".')
-
-    global_config.update(local_config)
-    return global_config
-
-
-CONFIGS = read_config()
-
-
-def show_configs():
-    msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
-    for k, v in CONFIGS.items():
-        if isinstance(v, dict):
-            if "password" in v:
-                v = copy.deepcopy(v)
-                v["password"] = "*" * 8
-            if "access_key" in v:
-                v = copy.deepcopy(v)
-                v["access_key"] = "*" * 8
-            if "secret_key" in v:
-                v = copy.deepcopy(v)
-                v["secret_key"] = "*" * 8
-            if "secret" in v:
-                v = copy.deepcopy(v)
-                v["secret"] = "*" * 8
-            if "sas_token" in v:
-                v = copy.deepcopy(v)
-                v["sas_token"] = "*" * 8
-            if "oauth" in k:
-                v = copy.deepcopy(v)
-                for key, val in v.items():
-                    if "client_secret" in val:
-                        val["client_secret"] = "*" * 8
-            if "authentication" in k:
-                v = copy.deepcopy(v)
-                for key, val in v.items():
-                    if "http_secret_key" in val:
-                        val["http_secret_key"] = "*" * 8
-        msg += f"\n\t{k}: {v}"
-    logging.info(msg)
-
-
-def get_base_config(key, default=None):
-    if key is None:
-        return None
-    if default is None:
-        default = os.environ.get(key.upper())
-    return CONFIGS.get(key, default)
-
-
-use_deserialize_safe_module = get_base_config(
-    'use_deserialize_safe_module', False)
-
-
-class BaseType:
-    def to_dict(self):
-        return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
-
-    def to_dict_with_type(self):
-        def _dict(obj):
-            module = None
-            if issubclass(obj.__class__, BaseType):
-                data = {}
-                for attr, v in obj.__dict__.items():
-                    k = attr.lstrip("_")
-                    data[k] = _dict(v)
-                module = obj.__module__
-            elif isinstance(obj, (list, tuple)):
-                data = []
-                for i, vv in enumerate(obj):
-                    data.append(_dict(vv))
-            elif isinstance(obj, dict):
-                data = {}
-                for _k, vv in obj.items():
-                    data[_k] = _dict(vv)
-            else:
-                data = obj
-            return {"type": obj.__class__.__name__,
-                    "data": data, "module": module}
-
-        return _dict(self)
-
-
-class CustomJSONEncoder(json.JSONEncoder):
-    def __init__(self, **kwargs):
-        self._with_type = kwargs.pop("with_type", False)
-        super().__init__(**kwargs)
-
-    def default(self, obj):
-        if isinstance(obj, datetime.datetime):
-            return obj.strftime('%Y-%m-%d %H:%M:%S')
-        elif isinstance(obj, datetime.date):
-            return obj.strftime('%Y-%m-%d')
-        elif isinstance(obj, datetime.timedelta):
-            return str(obj)
-        elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
-            return obj.value
-        elif isinstance(obj, set):
-            return list(obj)
-        elif issubclass(type(obj), BaseType):
-            if not self._with_type:
-                return obj.to_dict()
-            else:
-                return obj.to_dict_with_type()
-        elif isinstance(obj, type):
-            return obj.__name__
-        else:
-            return json.JSONEncoder.default(self, obj)
-
-
-def rag_uuid():
-    return uuid.uuid1().hex
-
-
-def string_to_bytes(string):
-    return string if isinstance(
-        string, bytes) else string.encode(encoding="utf-8")
-
-
-def bytes_to_string(byte):
-    return byte.decode(encoding="utf-8")
-
-
-def json_dumps(src, byte=False, indent=None, with_type=False):
-    dest = json.dumps(
-        src,
-        indent=indent,
-        cls=CustomJSONEncoder,
-        with_type=with_type)
-    if byte:
-        dest = string_to_bytes(dest)
-    return dest
-
-
-def json_loads(src, object_hook=None, object_pairs_hook=None):
-    if isinstance(src, bytes):
-        src = bytes_to_string(src)
-    return json.loads(src, object_hook=object_hook,
-                      object_pairs_hook=object_pairs_hook)
-
-
 def current_timestamp():
@ -215,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
     return time_stamp
 
 
-def serialize_b64(src, to_str=False):
-    dest = base64.b64encode(pickle.dumps(src))
-    if not to_str:
-        return dest
-    else:
-        return bytes_to_string(dest)
-
-
-def deserialize_b64(src):
-    src = base64.b64decode(
-        string_to_bytes(src) if isinstance(
-            src, str) else src)
-    if use_deserialize_safe_module:
-        return restricted_loads(src)
-    return pickle.loads(src)
-
-
-safe_module = {
-    'numpy',
-    'rag_flow'
-}
-
-
-class RestrictedUnpickler(pickle.Unpickler):
-    def find_class(self, module, name):
-        import importlib
-        if module.split('.')[0] in safe_module:
-            _module = importlib.import_module(module)
-            return getattr(_module, name)
-        # Forbid everything else.
-        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
-                                     (module, name))
-
-
-def restricted_loads(src):
-    """Helper function analogous to pickle.loads()."""
-    return RestrictedUnpickler(io.BytesIO(src)).load()
-
-
 def get_lan_ip():
     if os.name != "nt":
         import fcntl
@ -298,47 +90,6 @@ def from_dict_hook(in_dict: dict):
     return in_dict
 
 
-def decrypt_database_password(password):
-    encrypt_password = get_base_config("encrypt_password", False)
-    encrypt_module = get_base_config("encrypt_module", False)
-    private_key = get_base_config("private_key", None)
-
-    if not password or not encrypt_password:
-        return password
-
-    if not private_key:
-        raise ValueError("No private key")
-
-    module_fun = encrypt_module.split("#")
-    pwdecrypt_fun = getattr(
-        importlib.import_module(
-            module_fun[0]),
-        module_fun[1])
-
-    return pwdecrypt_fun(private_key, password)
-
-
-def decrypt_database_config(
-        database=None, passwd_key="password", name="database"):
-    if not database:
-        database = get_base_config(name, {})
-
-    database[passwd_key] = decrypt_database_password(database[passwd_key])
-    return database
-
-
-def update_config(key, value, conf_name=SERVICE_CONF):
-    conf_path = conf_realpath(conf_name=conf_name)
-    if not os.path.isabs(conf_path):
-        conf_path = os.path.join(
-            file_utils.get_project_base_directory(), conf_path)
-
-    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
-        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
-        config[key] = value
-        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
-
-
 def get_uuid():
     return uuid.uuid1().hex
 
@ -363,37 +114,6 @@ def elapsed2time(elapsed):
     return '%02d:%02d:%02d' % (hour, minuter, second)
 
 
-def decrypt(line):
-    file_path = os.path.join(
-        file_utils.get_project_base_directory(),
-        "conf",
-        "private.pem")
-    rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
-    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
-    return cipher.decrypt(base64.b64decode(
-        line), "Fail to decrypt password!").decode('utf-8')
-
-
-def decrypt2(crypt_text):
-    from base64 import b64decode, b16decode
-    from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
-    from Crypto.PublicKey import RSA
-    decode_data = b64decode(crypt_text)
-    if len(decode_data) == 127:
-        hex_fixed = '00' + decode_data.hex()
-        decode_data = b16decode(hex_fixed.upper())
-
-    file_path = os.path.join(
-        file_utils.get_project_base_directory(),
-        "conf",
-        "private.pem")
-    pem = open(file_path).read()
-    rsa_key = RSA.importKey(pem, "Welcome")
-    cipher = Cipher_PKCS1_v1_5.new(rsa_key)
-    decrypt_text = cipher.decrypt(decode_data, None)
-    return (b64decode(decrypt_text)).decode()
-
-
 def download_img(url):
     if not url:
         return ""
@ -408,5 +128,5 @@ def delta_seconds(date_string: str):
     return (datetime.datetime.now() - dt).total_seconds()
 
 
-def hash_str2int(line:str, mod: int=10 ** 8) -> int:
+def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
     return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
@ -39,6 +39,7 @@ from flask import (
     make_response,
     send_file,
 )
+from flask_login import current_user
 from flask import (
     request as flask_request,
 )
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
 
 from api import settings
 from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
+from api.db import ActiveEnum
 from api.db.db_models import APIToken
+from api.db.services import UserService
 from api.db.services.llm_service import LLMService
 from api.db.services.tenant_llm_service import TenantLLMService
-from api.utils import CustomJSONEncoder, get_uuid, json_dumps
+from api.utils.json import CustomJSONEncoder, json_dumps
+from api.utils import get_uuid
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
 
 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
     return decorator
 
 
+def active_required(f):
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        user_id = current_user.id
+        usr = UserService.filter_by_id(user_id)
+        # check is_active
+        if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
+            return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
+        return f(*args, **kwargs)
+    return wrapper
+
+
 def is_localhost(ip):
     return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
 
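
A hedged sketch of how the new `active_required` decorator is meant to be stacked on a login-protected Flask route. The blueprint, route path, and handler body are hypothetical, not part of the patch:

```python
# Illustration only: stacking active_required on a login-protected route.
# The blueprint, endpoint path and handler are hypothetical examples.
from flask import Blueprint
from flask_login import login_required
from api.utils.api_utils import active_required, get_json_result

example_bp = Blueprint("example", __name__)


@example_bp.route("/example/ping", methods=["GET"])  # hypothetical endpoint
@login_required
@active_required  # rejects users whose is_active != ActiveEnum.ACTIVE
def ping():
    return get_json_result(data="pong")
```
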
api/utils/common.py (new file, 23 lines)
@ -0,0 +1,23 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+def string_to_bytes(string):
+    return string if isinstance(
+        string, bytes) else string.encode(encoding="utf-8")
+
+
+def bytes_to_string(byte):
+    return byte.decode(encoding="utf-8")
api/utils/configs.py (new file, 179 lines)
@ -0,0 +1,179 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import io
+import copy
+import logging
+import base64
+import pickle
+import importlib
+
+from api.utils import file_utils
+from filelock import FileLock
+from api.utils.common import bytes_to_string, string_to_bytes
+from api.constants import SERVICE_CONF
+
+
+def conf_realpath(conf_name):
+    conf_path = f"conf/{conf_name}"
+    return os.path.join(file_utils.get_project_base_directory(), conf_path)
+
+
+def read_config(conf_name=SERVICE_CONF):
+    local_config = {}
+    local_path = conf_realpath(f'local.{conf_name}')
+
+    # load local config file
+    if os.path.exists(local_path):
+        local_config = file_utils.load_yaml_conf(local_path)
+        if not isinstance(local_config, dict):
+            raise ValueError(f'Invalid config file: "{local_path}".')
+
+    global_config_path = conf_realpath(conf_name)
+    global_config = file_utils.load_yaml_conf(global_config_path)
+
+    if not isinstance(global_config, dict):
+        raise ValueError(f'Invalid config file: "{global_config_path}".')
+
+    global_config.update(local_config)
+    return global_config
+
+
+CONFIGS = read_config()
+
+
+def show_configs():
+    msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
+    for k, v in CONFIGS.items():
+        if isinstance(v, dict):
+            if "password" in v:
+                v = copy.deepcopy(v)
+                v["password"] = "*" * 8
+            if "access_key" in v:
+                v = copy.deepcopy(v)
+                v["access_key"] = "*" * 8
+            if "secret_key" in v:
+                v = copy.deepcopy(v)
+                v["secret_key"] = "*" * 8
+            if "secret" in v:
+                v = copy.deepcopy(v)
+                v["secret"] = "*" * 8
+            if "sas_token" in v:
+                v = copy.deepcopy(v)
+                v["sas_token"] = "*" * 8
+            if "oauth" in k:
+                v = copy.deepcopy(v)
+                for key, val in v.items():
+                    if "client_secret" in val:
+                        val["client_secret"] = "*" * 8
+            if "authentication" in k:
+                v = copy.deepcopy(v)
+                for key, val in v.items():
+                    if "http_secret_key" in val:
+                        val["http_secret_key"] = "*" * 8
+        msg += f"\n\t{k}: {v}"
+    logging.info(msg)
+
+
+def get_base_config(key, default=None):
+    if key is None:
+        return None
+    if default is None:
+        default = os.environ.get(key.upper())
+    return CONFIGS.get(key, default)
+
+
+def decrypt_database_password(password):
+    encrypt_password = get_base_config("encrypt_password", False)
+    encrypt_module = get_base_config("encrypt_module", False)
+    private_key = get_base_config("private_key", None)
+
+    if not password or not encrypt_password:
+        return password
+
+    if not private_key:
+        raise ValueError("No private key")
+
+    module_fun = encrypt_module.split("#")
+    pwdecrypt_fun = getattr(
+        importlib.import_module(
+            module_fun[0]),
+        module_fun[1])
+
+    return pwdecrypt_fun(private_key, password)
+
+
+def decrypt_database_config(
+        database=None, passwd_key="password", name="database"):
+    if not database:
+        database = get_base_config(name, {})
+
+    database[passwd_key] = decrypt_database_password(database[passwd_key])
+    return database
+
+
+def update_config(key, value, conf_name=SERVICE_CONF):
+    conf_path = conf_realpath(conf_name=conf_name)
+    if not os.path.isabs(conf_path):
+        conf_path = os.path.join(
+            file_utils.get_project_base_directory(), conf_path)
+
+    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
+        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
+        config[key] = value
+        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
+
+
+safe_module = {
+    'numpy',
+    'rag_flow'
+}
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        import importlib
+        if module.split('.')[0] in safe_module:
+            _module = importlib.import_module(module)
+            return getattr(_module, name)
+        # Forbid everything else.
+        raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
+                                     (module, name))
+
+
+def restricted_loads(src):
+    """Helper function analogous to pickle.loads()."""
+    return RestrictedUnpickler(io.BytesIO(src)).load()
+
+
+def serialize_b64(src, to_str=False):
+    dest = base64.b64encode(pickle.dumps(src))
+    if not to_str:
+        return dest
+    else:
+        return bytes_to_string(dest)
+
+
+def deserialize_b64(src):
+    src = base64.b64decode(
+        string_to_bytes(src) if isinstance(
+            src, str) else src)
+    use_deserialize_safe_module = get_base_config(
+        'use_deserialize_safe_module', False)
+    if use_deserialize_safe_module:
+        return restricted_loads(src)
+    return pickle.loads(src)
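
A short usage sketch of the relocated helpers in `api/utils/configs.py`; it assumes `conf/service_conf.yaml` exists in the project, and the `"mysql"` key is only an example section name:

```python
# Sketch: reading service_conf values through the relocated helpers.
# Assumes conf/service_conf.yaml is present; the "mysql" key is an example.
from api.utils.configs import get_base_config, show_configs

mysql_conf = get_base_config("mysql", {})   # returns {} when the key is absent
show_configs()                              # logs all sections with secrets masked as ********
```
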
api/utils/crypt.py (new file, 64 lines)
@ -0,0 +1,64 @@
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from Cryptodome.PublicKey import RSA
|
||||||
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||||
|
from api.utils import file_utils
|
||||||
|
|
||||||
|
|
||||||
|
def crypt(line):
|
||||||
|
"""
|
||||||
|
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||||
|
"""
|
||||||
|
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||||
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||||
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||||
|
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||||
|
encrypted_password = cipher.encrypt(password_base64.encode())
|
||||||
|
return base64.b64encode(encrypted_password).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt(line):
|
||||||
|
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||||
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||||
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||||
|
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt2(crypt_text):
|
||||||
|
from base64 import b64decode, b16decode
|
||||||
|
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
||||||
|
from Crypto.PublicKey import RSA
|
||||||
|
decode_data = b64decode(crypt_text)
|
||||||
|
if len(decode_data) == 127:
|
||||||
|
hex_fixed = '00' + decode_data.hex()
|
||||||
|
decode_data = b16decode(hex_fixed.upper())
|
||||||
|
|
||||||
|
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||||
|
pem = open(file_path).read()
|
||||||
|
rsa_key = RSA.importKey(pem, "Welcome")
|
||||||
|
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
||||||
|
decrypt_text = cipher.decrypt(decode_data, None)
|
||||||
|
return (b64decode(decrypt_text)).decode()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
passwd = crypt(sys.argv[1])
|
||||||
|
print(passwd)
|
||||||
|
print(decrypt(passwd))
|
||||||
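A minimal round-trip sketch for the helpers above; it assumes conf/public.pem and conf/private.pem exist and form a matching RSA key pair, as crypt() and decrypt() expect:

# Sketch only: per the crypt() docstring, decrypt(crypt(x)) returns base64(x), not x itself.
import base64
from api.utils.crypt import crypt, decrypt

token = crypt("admin-password")                              # RSA-encrypt the base64 of the plaintext
assert decrypt(token) == base64.b64encode(b"admin-password").decode()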
107  api/utils/health_utils.py  Normal file
@ -0,0 +1,107 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from timeit import default_timer as timer

from api import settings
from api.db.db_models import DB
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL


def _ok_nok(ok: bool) -> str:
    return "ok" if ok else "nok"


def check_db() -> tuple[bool, dict]:
    st = timer()
    try:
        # lightweight probe; works for MySQL/Postgres
        DB.execute_sql("SELECT 1")
        return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
    except Exception as e:
        return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}


def check_redis() -> tuple[bool, dict]:
    st = timer()
    try:
        ok = bool(REDIS_CONN.health())
        return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
    except Exception as e:
        return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}


def check_doc_engine() -> tuple[bool, dict]:
    st = timer()
    try:
        meta = settings.docStoreConn.health()
        # treat any successful call as ok
        return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
    except Exception as e:
        return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}


def check_storage() -> tuple[bool, dict]:
    st = timer()
    try:
        STORAGE_IMPL.health()
        return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
    except Exception as e:
        return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}


def run_health_checks() -> tuple[dict, bool]:
    result: dict[str, str | dict] = {}

    db_ok, db_meta = check_db()
    result["db"] = _ok_nok(db_ok)
    if not db_ok:
        result.setdefault("_meta", {})["db"] = db_meta

    try:
        redis_ok, redis_meta = check_redis()
        result["redis"] = _ok_nok(redis_ok)
        if not redis_ok:
            result.setdefault("_meta", {})["redis"] = redis_meta
    except Exception:
        result["redis"] = "nok"

    try:
        doc_ok, doc_meta = check_doc_engine()
        result["doc_engine"] = _ok_nok(doc_ok)
        if not doc_ok:
            result.setdefault("_meta", {})["doc_engine"] = doc_meta
    except Exception:
        result["doc_engine"] = "nok"

    try:
        sto_ok, sto_meta = check_storage()
        result["storage"] = _ok_nok(sto_ok)
        if not sto_ok:
            result.setdefault("_meta", {})["storage"] = sto_meta
    except Exception:
        result["storage"] = "nok"

    all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok")
    result["status"] = "ok" if all_ok else "nok"
    return result, all_ok
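A short usage sketch for the aggregate check above, e.g. from an admin endpoint or a readiness probe:

# Sketch only: assumes the RAGFlow settings and service connections are already initialized.
from api.utils.health_utils import run_health_checks

report, healthy = run_health_checks()
print(report["status"])          # "ok" only if db, redis, doc_engine and storage are all "ok"
if not healthy:
    print(report.get("_meta"))   # per-component elapsed time and error details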
78  api/utils/json.py  Normal file
@ -0,0 +1,78 @@
import datetime
import json
from enum import Enum, IntEnum
from api.utils.common import string_to_bytes, bytes_to_string


class BaseType:
    def to_dict(self):
        return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])

    def to_dict_with_type(self):
        def _dict(obj):
            module = None
            if issubclass(obj.__class__, BaseType):
                data = {}
                for attr, v in obj.__dict__.items():
                    k = attr.lstrip("_")
                    data[k] = _dict(v)
                module = obj.__module__
            elif isinstance(obj, (list, tuple)):
                data = []
                for i, vv in enumerate(obj):
                    data.append(_dict(vv))
            elif isinstance(obj, dict):
                data = {}
                for _k, vv in obj.items():
                    data[_k] = _dict(vv)
            else:
                data = obj
            return {"type": obj.__class__.__name__,
                    "data": data, "module": module}

        return _dict(self)


class CustomJSONEncoder(json.JSONEncoder):
    def __init__(self, **kwargs):
        self._with_type = kwargs.pop("with_type", False)
        super().__init__(**kwargs)

    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        elif isinstance(obj, datetime.date):
            return obj.strftime('%Y-%m-%d')
        elif isinstance(obj, datetime.timedelta):
            return str(obj)
        elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
            return obj.value
        elif isinstance(obj, set):
            return list(obj)
        elif issubclass(type(obj), BaseType):
            if not self._with_type:
                return obj.to_dict()
            else:
                return obj.to_dict_with_type()
        elif isinstance(obj, type):
            return obj.__name__
        else:
            return json.JSONEncoder.default(self, obj)


def json_dumps(src, byte=False, indent=None, with_type=False):
    dest = json.dumps(
        src,
        indent=indent,
        cls=CustomJSONEncoder,
        with_type=with_type)
    if byte:
        dest = string_to_bytes(dest)
    return dest


def json_loads(src, object_hook=None, object_pairs_hook=None):
    if isinstance(src, bytes):
        src = bytes_to_string(src)
    return json.loads(src, object_hook=object_hook,
                      object_pairs_hook=object_pairs_hook)
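A minimal sketch of a datetime- and set-aware round-trip with the encoder above:

# Sketch only: the sample document is an assumption.
import datetime
from api.utils.json import json_dumps, json_loads

doc = {"created": datetime.datetime(2025, 1, 1, 12, 0, 0), "tags": {"kb", "agent"}}
text = json_dumps(doc, indent=2)      # datetimes become "YYYY-MM-DD HH:MM:SS", sets become lists
data = json_loads(text)
print(data["created"], data["tags"])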
@ -1,40 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import base64
import os
import sys
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from api.utils import decrypt, file_utils


def crypt(line):
    file_path = os.path.join(
        file_utils.get_project_base_directory(),
        "conf",
        "public.pem")
    rsa_key = RSA.importKey(open(file_path).read(),"Welcome")
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
    encrypted_password = cipher.encrypt(password_base64.encode())
    return base64.b64encode(encrypted_password).decode('utf-8')


if __name__ == "__main__":
    passwd = crypt(sys.argv[1])
    print(passwd)
    print(decrypt(passwd))
19  chat_demo/index.html  Normal file
@ -0,0 +1,19 @@
<iframe src="http://localhost:9222/next-chats/widget?shared_id=9dcfc68696c611f0bb789b9b8b765d12&from=chat&auth=U4MDU3NzkwOTZjNzExZjBiYjc4OWI5Yj&mode=master&streaming=false"
        style="position:fixed;bottom:0;right:0;width:100px;height:100px;border:none;background:transparent;z-index:9999"
        frameborder="0" allow="microphone;camera"></iframe>
<script>
window.addEventListener('message',e=>{
  if(e.origin!=='http://localhost:9222')return;
  if(e.data.type==='CREATE_CHAT_WINDOW'){
    if(document.getElementById('chat-win'))return;
    const i=document.createElement('iframe');
    i.id='chat-win';i.src=e.data.src;
    i.style.cssText='position:fixed;bottom:104px;right:24px;width:380px;height:500px;border:none;background:transparent;z-index:9998;display:none';
    i.frameBorder='0';i.allow='microphone;camera';
    document.body.appendChild(i);
  }else if(e.data.type==='TOGGLE_CHAT'){
    const w=document.getElementById('chat-win');
    if(w)w.style.display=e.data.isOpen?'block':'none';
  }else if(e.data.type==='SCROLL_PASSTHROUGH')window.scrollBy(0,e.data.deltaY);
});
</script>
154  chat_demo/widget_demo.html  Normal file
@ -0,0 +1,154 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Floating Chat Widget Demo</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 40px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            color: white;
        }

        .demo-content {
            max-width: 800px;
            margin: 0 auto;
        }

        .demo-content h1 {
            text-align: center;
            font-size: 2.5rem;
            margin-bottom: 2rem;
        }

        .demo-content p {
            font-size: 1.2rem;
            line-height: 1.6;
            margin-bottom: 1.5rem;
        }

        .feature-list {
            background: rgba(255, 255, 255, 0.1);
            border-radius: 10px;
            padding: 2rem;
            margin: 2rem 0;
        }

        .feature-list h3 {
            margin-top: 0;
            font-size: 1.5rem;
        }

        .feature-list ul {
            list-style-type: none;
            padding: 0;
        }

        .feature-list li {
            padding: 0.5rem 0;
            padding-left: 1.5rem;
            position: relative;
        }

        .feature-list li:before {
            content: "✓";
            position: absolute;
            left: 0;
            color: #4ade80;
            font-weight: bold;
        }
    </style>
</head>
<body>
    <div class="demo-content">
        <h1>🚀 Floating Chat Widget Demo</h1>

        <p>
            Welcome to our demo page! This page simulates a real website with content.
            Look for the floating chat button in the bottom-right corner - just like Intercom!
        </p>

        <div class="feature-list">
            <h3>🎯 Widget Features</h3>
            <ul>
                <li>Floating button that stays visible while scrolling</li>
                <li>Click to open/close the chat window</li>
                <li>Minimize button to collapse the chat</li>
                <li>Professional Intercom-style design</li>
                <li>Unread message indicator (red badge)</li>
                <li>Transparent background integration</li>
                <li>Responsive design for all screen sizes</li>
            </ul>
        </div>

        <p>
            The chat widget is completely separate from your website's content and won't
            interfere with your existing layout or functionality. It's designed to be
            lightweight and performant.
        </p>

        <p>
            Try scrolling this page - notice how the chat button stays in position.
            Click it to start a conversation with our AI assistant!
        </p>

        <div class="feature-list">
            <h3>🔧 Implementation</h3>
            <ul>
                <li>Simple iframe embed - just copy and paste</li>
                <li>No JavaScript dependencies required</li>
                <li>Works on any website or platform</li>
                <li>Customizable appearance and behavior</li>
                <li>Secure and privacy-focused</li>
            </ul>
        </div>

        <p>
            This is just placeholder content to demonstrate how the widget integrates
            seamlessly with your existing website content. The widget floats above
            everything else without disrupting your user experience.
        </p>

        <p style="margin-top: 4rem; text-align: center; font-style: italic;">
            🎉 Ready to add this to your website? Get your embed code from the admin panel!
        </p>
    </div>

    <iframe id="main-widget" src="http://localhost:9222/next-chats/widget?shared_id=9dcfc68696c611f0bb789b9b8b765d12&from=chat&auth=U4MDU3NzkwOTZjNzExZjBiYjc4OWI5Yj&visible_avatar=1&locale=zh&mode=master&streaming=false"
            style="position:fixed;bottom:0;right:0;width:100px;height:100px;border:none;background:transparent;z-index:9999;opacity:0;transition:opacity 0.2s ease"
            frameborder="0" allow="microphone;camera"></iframe>
    <script>
    window.addEventListener('message',e=>{
      if(e.origin!=='http://localhost:9222')return;
      if(e.data.type==='WIDGET_READY'){
        // Show the main widget when React is ready
        const mainWidget = document.getElementById('main-widget');
        if(mainWidget) mainWidget.style.opacity = '1';
      }else if(e.data.type==='CREATE_CHAT_WINDOW'){
        if(document.getElementById('chat-win'))return;
        const i=document.createElement('iframe');
        i.id='chat-win';i.src=e.data.src;
        i.style.cssText='position:fixed;bottom:104px;right:24px;width:380px;height:500px;border:none;background:transparent;z-index:9998;display:none;opacity:0;transition:opacity 0.2s ease';
        i.frameBorder='0';i.allow='microphone;camera';
        document.body.appendChild(i);
      }else if(e.data.type==='TOGGLE_CHAT'){
        const w=document.getElementById('chat-win');
        if(w){
          if(e.data.isOpen){
            w.style.display='block';
            // Wait for the iframe content to be ready before showing
            setTimeout(() => w.style.opacity='1', 100);
          }else{
            w.style.opacity='0';
            setTimeout(() => w.style.display='none', 200);
          }
        }
      }else if(e.data.type==='SCROLL_PASSTHROUGH')window.scrollBy(0,e.data.deltaY);
    });
    </script>
</body>
</html>
@ -219,6 +219,70 @@
            }
        ]
    },
+   {
+       "name": "TokenPony",
+       "logo": "",
+       "tags": "LLM",
+       "status": "1",
+       "llm": [
+           {"llm_name": "qwen3-8b", "tags": "LLM,CHAT,131k", "max_tokens": 131000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "deepseek-v3-0324", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "qwen3-32b", "tags": "LLM,CHAT,131k", "max_tokens": 131000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "kimi-k2-instruct", "tags": "LLM,CHAT,128K", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "deepseek-r1-0528", "tags": "LLM,CHAT,164k", "max_tokens": 164000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "qwen3-coder-480b", "tags": "LLM,CHAT,1024k", "max_tokens": 1024000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "glm-4.5", "tags": "LLM,CHAT,131K", "max_tokens": 131000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "deepseek-v3.1", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true}
+       ]
+   },
    {
        "name": "Tongyi-Qianwen",
        "logo": "",
@ -338,7 +402,7 @@
            "is_tools": true
        },
        {
-           "llm_name": "qwen3-max-preview",
+           "llm_name": "qwen3-max",
            "tags": "LLM,CHAT,256k",
            "max_tokens": 256000,
            "model_type": "chat",
@ -372,6 +436,27 @@
            "model_type": "chat",
            "is_tools": true
        },
+       {"llm_name": "qwen3-vl-plus", "tags": "LLM,CHAT,IMAGE2TEXT,256k", "max_tokens": 256000, "model_type": "image2text", "is_tools": true},
+       {"llm_name": "qwen3-vl-235b-a22b-instruct", "tags": "LLM,CHAT,IMAGE2TEXT,128k", "max_tokens": 128000, "model_type": "image2text", "is_tools": true},
+       {"llm_name": "qwen3-vl-235b-a22b-thinking", "tags": "LLM,CHAT,IMAGE2TEXT,128k", "max_tokens": 128000, "model_type": "image2text", "is_tools": true},
        {
            "llm_name": "qwen3-235b-a22b-instruct-2507",
            "tags": "LLM,CHAT,128k",
@ -393,6 +478,20 @@
            "model_type": "chat",
            "is_tools": true
        },
+       {"llm_name": "qwen3-next-80b-a3b-instruct", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
+       {"llm_name": "qwen3-next-80b-a3b-thinking", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
        {
            "llm_name": "qwen3-0.6b",
            "tags": "LLM,CHAT,32k",
@ -558,6 +657,13 @@
            "tags": "SPEECH2TEXT,8k",
            "max_tokens": 8000,
            "model_type": "speech2text"
+       },
+       {
+           "llm_name": "qianwen-deepresearch-30b-a3b-131k",
+           "tags": "LLM,CHAT,1M,AGENT,DEEPRESEARCH",
+           "max_tokens": 1000000,
+           "model_type": "chat",
+           "is_tools": true
        }
        ]
    },
@ -625,7 +731,7 @@
        },
        {
            "llm_name": "glm-4",
-           "tags":"LLM,CHAT,128K",
+           "tags": "LLM,CHAT,128K",
            "max_tokens": 128000,
            "model_type": "chat",
            "is_tools": true
@ -4477,6 +4583,273 @@
            }
        ]
    },
+   {
+       "name": "CometAPI",
+       "logo": "",
+       "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
+       "status": "1",
+       "llm": [
+           {"llm_name": "gpt-5-chat-latest", "tags": "LLM,CHAT,400k", "max_tokens": 400000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "chatgpt-4o-latest", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gpt-5-mini", "tags": "LLM,CHAT,400k", "max_tokens": 400000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gpt-5-nano", "tags": "LLM,CHAT,400k", "max_tokens": 400000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gpt-5", "tags": "LLM,CHAT,400k", "max_tokens": 400000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gpt-4.1-mini", "tags": "LLM,CHAT,1M", "max_tokens": 1047576, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gpt-4.1-nano", "tags": "LLM,CHAT,1M", "max_tokens": 1047576, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gpt-4.1", "tags": "LLM,CHAT,1M", "max_tokens": 1047576, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gpt-4o-mini", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "o4-mini-2025-04-16", "tags": "LLM,CHAT,200k", "max_tokens": 200000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "o3-pro-2025-06-10", "tags": "LLM,CHAT,200k", "max_tokens": 200000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "claude-opus-4-1-20250805", "tags": "LLM,CHAT,200k,IMAGE2TEXT", "max_tokens": 200000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "claude-opus-4-1-20250805-thinking", "tags": "LLM,CHAT,200k,IMAGE2TEXT", "max_tokens": 200000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "claude-sonnet-4-20250514", "tags": "LLM,CHAT,200k,IMAGE2TEXT", "max_tokens": 200000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "claude-sonnet-4-20250514-thinking", "tags": "LLM,CHAT,200k,IMAGE2TEXT", "max_tokens": 200000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "claude-3-7-sonnet-latest", "tags": "LLM,CHAT,200k", "max_tokens": 200000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "claude-3-5-haiku-latest", "tags": "LLM,CHAT,200k", "max_tokens": 200000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "gemini-2.5-pro", "tags": "LLM,CHAT,1M,IMAGE2TEXT", "max_tokens": 1000000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "gemini-2.5-flash", "tags": "LLM,CHAT,1M,IMAGE2TEXT", "max_tokens": 1000000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "gemini-2.5-flash-lite", "tags": "LLM,CHAT,1M,IMAGE2TEXT", "max_tokens": 1000000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "gemini-2.0-flash", "tags": "LLM,CHAT,1M,IMAGE2TEXT", "max_tokens": 1000000, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "grok-4-0709", "tags": "LLM,CHAT,131k", "max_tokens": 131072, "model_type": "chat", "is_tools": true},
+           {"llm_name": "grok-3", "tags": "LLM,CHAT,131k", "max_tokens": 131072, "model_type": "chat", "is_tools": true},
+           {"llm_name": "grok-3-mini", "tags": "LLM,CHAT,131k", "max_tokens": 131072, "model_type": "chat", "is_tools": true},
+           {"llm_name": "grok-2-image-1212", "tags": "LLM,CHAT,32k,IMAGE2TEXT", "max_tokens": 32768, "model_type": "image2text", "is_tools": true},
+           {"llm_name": "deepseek-v3.1", "tags": "LLM,CHAT,64k", "max_tokens": 64000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "deepseek-v3", "tags": "LLM,CHAT,64k", "max_tokens": 64000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "deepseek-r1-0528", "tags": "LLM,CHAT,164k", "max_tokens": 164000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "deepseek-chat", "tags": "LLM,CHAT,32k", "max_tokens": 32000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "deepseek-reasoner", "tags": "LLM,CHAT,64k", "max_tokens": 64000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "qwen3-30b-a3b", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "qwen3-coder-plus-2025-07-22", "tags": "LLM,CHAT,128k", "max_tokens": 128000, "model_type": "chat", "is_tools": true},
+           {"llm_name": "text-embedding-ada-002", "tags": "TEXT EMBEDDING,8K", "max_tokens": 8191, "model_type": "embedding", "is_tools": false},
+           {"llm_name": "text-embedding-3-small", "tags": "TEXT EMBEDDING,8K", "max_tokens": 8191, "model_type": "embedding", "is_tools": false},
+           {"llm_name": "text-embedding-3-large", "tags": "TEXT EMBEDDING,8K", "max_tokens": 8191, "model_type": "embedding", "is_tools": false},
+           {"llm_name": "whisper-1", "tags": "SPEECH2TEXT", "max_tokens": 26214400, "model_type": "speech2text", "is_tools": false},
+           {"llm_name": "tts-1", "tags": "TTS", "max_tokens": 2048, "model_type": "tts", "is_tools": false}
+       ]
+   },
    {
        "name": "Meituan",
        "logo": "",
@ -1,6 +1,9 @@
ragflow:
  host: 0.0.0.0
  http_port: 9380
+admin:
+  host: 0.0.0.0
+  http_port: 9381
mysql:
  name: 'rag_flow'
  user: 'root'
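A minimal sketch of reading the new admin block from the service config; it assumes the YAML above has been loaded into CONFIGS by the utilities shown earlier:

# Sketch only: the default values mirror the template above.
from api.utils import get_base_config

admin_conf = get_base_config("admin", {"host": "0.0.0.0", "http_port": 9381})
print(admin_conf["host"], admin_conf["http_port"])   # address the Admin Service binds to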
@ -22,10 +22,10 @@ from openpyxl import Workbook, load_workbook
 from rag.nlp import find_codec

 # copied from `/openpyxl/cell/cell.py`
-ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')
+ILLEGAL_CHARACTERS_RE = re.compile(r"[\000-\010]|[\013-\014]|[\016-\037]")


 class RAGFlowExcelParser:

     @staticmethod
     def _load_excel_to_workbook(file_like_object):
         if isinstance(file_like_object, bytes):
@ -36,7 +36,7 @@ class RAGFlowExcelParser:
         file_head = file_like_object.read(4)
         file_like_object.seek(0)

-        if not (file_head.startswith(b'PK\x03\x04') or file_head.startswith(b'\xD0\xCF\x11\xE0')):
+        if not (file_head.startswith(b"PK\x03\x04") or file_head.startswith(b"\xd0\xcf\x11\xe0")):
             logging.info("Not an Excel file, converting CSV to Excel Workbook")

             try:
@ -48,7 +48,7 @@ class RAGFlowExcelParser:
             raise Exception(f"Failed to parse CSV and convert to Excel Workbook: {e_csv}")

         try:
-            return load_workbook(file_like_object,data_only= True)
+            return load_workbook(file_like_object, data_only=True)
         except Exception as e:
             logging.info(f"openpyxl load error: {e}, try pandas instead")
             try:
@ -59,7 +59,7 @@ class RAGFlowExcelParser:
             except Exception as ex:
                 logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
                 file_like_object.seek(0)
-                df = pd.read_excel(file_like_object, engine='calamine')
+                df = pd.read_excel(file_like_object, engine="calamine")
                 return RAGFlowExcelParser._dataframe_to_workbook(df)
         except Exception as e_pandas:
             raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
@ -116,9 +116,7 @@ class RAGFlowExcelParser:
             tb = ""
             tb += f"<table><caption>{sheetname}</caption>"
             tb += tb_rows_0
-            for r in list(
-                    rows[1 + chunk_i * chunk_rows: min(1 + (chunk_i + 1) * chunk_rows, len(rows))]
-            ):
+            for r in list(rows[1 + chunk_i * chunk_rows : min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
                 tb += "<tr>"
                 for i, c in enumerate(r):
                     if c.value is None:
@ -133,8 +131,16 @@ class RAGFlowExcelParser:

     def markdown(self, fnm):
         import pandas as pd

         file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
-        df = pd.read_excel(file_like_object)
+        try:
+            file_like_object.seek(0)
+            df = pd.read_excel(file_like_object)
+        except Exception as e:
+            logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
+            file_like_object.seek(0)
+            df = pd.read_csv(file_like_object)
+        df = df.replace(r"^\s*$", "", regex=True)
         return df.to_markdown(index=False)

     def __call__(self, fnm):
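A brief usage sketch for the markdown() fallback above, which now retries the input as CSV when pandas cannot read it as a spreadsheet:

# Sketch only: the deepdoc.parser.excel_parser module path and the file name are assumptions.
from deepdoc.parser.excel_parser import RAGFlowExcelParser

parser = RAGFlowExcelParser()
with open("report.csv", "rb") as f:
    print(parser.markdown(f.read()))   # parsed as CSV after the Excel attempt fails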
@ -19,7 +19,7 @@ from PIL import Image

 from api.utils.api_utils import timeout
 from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
-from rag.prompts import vision_llm_figure_describe_prompt
+from rag.prompts.generator import vision_llm_figure_describe_prompt


 def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
@ -37,7 +37,7 @@ TITLE_TAGS = {"h1": "#", "h2": "##", "h3": "###", "h4": "#####", "h5": "#####",


 class RAGFlowHtmlParser:
-    def __call__(self, fnm, binary=None, chunk_token_num=None):
+    def __call__(self, fnm, binary=None, chunk_token_num=512):
         if binary:
             encoding = find_codec(binary)
             txt = binary.decode(encoding, errors="ignore")
@ -34,10 +34,10 @@ from pypdf import PdfReader as pdf2_read

 from api import settings
 from api.utils.file_utils import get_project_base_directory
-from deepdoc.vision import OCR, LayoutRecognizer, Recognizer, TableStructureRecognizer
+from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recognizer, TableStructureRecognizer
 from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
 from rag.nlp import rag_tokenizer
-from rag.prompts import vision_llm_describe_prompt
+from rag.prompts.generator import vision_llm_describe_prompt
 from rag.settings import PARALLEL_DEVICES

 LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
@ -64,33 +64,38 @@ class RAGFlowPdfParser:
         if PARALLEL_DEVICES > 1:
             self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]

+        layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
+        if layout_recognizer_type not in ["onnx", "ascend"]:
+            raise RuntimeError("Unsupported layout recognizer type.")
+
         if hasattr(self, "model_speciess"):
-            self.layouter = LayoutRecognizer("layout." + self.model_speciess)
+            recognizer_domain = "layout." + self.model_speciess
         else:
-            self.layouter = LayoutRecognizer("layout")
+            recognizer_domain = "layout"
+
+        if layout_recognizer_type == "ascend":
+            logging.debug("Using Ascend LayoutRecognizer")
+            self.layouter = AscendLayoutRecognizer(recognizer_domain)
+        else:  # onnx
+            logging.debug("Using Onnx LayoutRecognizer")
+            self.layouter = LayoutRecognizer(recognizer_domain)
         self.tbl_det = TableStructureRecognizer()

         self.updown_cnt_mdl = xgb.Booster()
         if not settings.LIGHTEN:
             try:
                 import torch.cuda

                 if torch.cuda.is_available():
                     self.updown_cnt_mdl.set_param({"device": "cuda"})
             except Exception:
                 logging.exception("RAGFlowPdfParser __init__")
         try:
-            model_dir = os.path.join(
-                get_project_base_directory(),
-                "rag/res/deepdoc")
-            self.updown_cnt_mdl.load_model(os.path.join(
-                model_dir, "updown_concat_xgb.model"))
+            model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
+            self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
         except Exception:
-            model_dir = snapshot_download(
-                repo_id="InfiniFlow/text_concat_xgb_v1.0",
-                local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
-                local_dir_use_symlinks=False)
-            self.updown_cnt_mdl.load_model(os.path.join(
-                model_dir, "updown_concat_xgb.model"))
+            model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False)
+            self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))

         self.page_from = 0
         self.column_num = 1
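A minimal sketch of selecting the layout backend that the constructor above reads from the environment; only "onnx" (the default) and "ascend" are accepted, anything else raises RuntimeError:

# Sketch only: the deepdoc.parser.pdf_parser module path is an assumption.
import os
os.environ["LAYOUT_RECOGNIZER_TYPE"] = "ascend"

from deepdoc.parser.pdf_parser import RAGFlowPdfParser
pdf_parser = RAGFlowPdfParser()   # constructs AscendLayoutRecognizer instead of the ONNX recognizer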
@ -102,13 +107,10 @@ class RAGFlowPdfParser:
|
|||||||
return c["bottom"] - c["top"]
|
return c["bottom"] - c["top"]
|
||||||
|
|
||||||
def _x_dis(self, a, b):
|
def _x_dis(self, a, b):
|
||||||
return min(abs(a["x1"] - b["x0"]), abs(a["x0"] - b["x1"]),
|
return min(abs(a["x1"] - b["x0"]), abs(a["x0"] - b["x1"]), abs(a["x0"] + a["x1"] - b["x0"] - b["x1"]) / 2)
|
||||||
abs(a["x0"] + a["x1"] - b["x0"] - b["x1"]) / 2)
|
|
||||||
|
|
||||||
def _y_dis(
|
def _y_dis(self, a, b):
|
||||||
self, a, b):
|
return (b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
|
||||||
return (
|
|
||||||
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
|
|
||||||
|
|
||||||
def _match_proj(self, b):
|
def _match_proj(self, b):
|
||||||
proj_patt = [
|
proj_patt = [
|
||||||
@ -130,10 +132,7 @@ class RAGFlowPdfParser:
|
|||||||
LEN = 6
|
LEN = 6
|
||||||
tks_down = rag_tokenizer.tokenize(down["text"][:LEN]).split()
|
tks_down = rag_tokenizer.tokenize(down["text"][:LEN]).split()
|
||||||
tks_up = rag_tokenizer.tokenize(up["text"][-LEN:]).split()
|
tks_up = rag_tokenizer.tokenize(up["text"][-LEN:]).split()
|
||||||
tks_all = up["text"][-LEN:].strip() \
|
tks_all = up["text"][-LEN:].strip() + (" " if re.match(r"[a-zA-Z0-9]+", up["text"][-1] + down["text"][0]) else "") + down["text"][:LEN].strip()
|
||||||
+ (" " if re.match(r"[a-zA-Z0-9]+",
|
|
||||||
up["text"][-1] + down["text"][0]) else "") \
|
|
||||||
+ down["text"][:LEN].strip()
|
|
||||||
tks_all = rag_tokenizer.tokenize(tks_all).split()
|
tks_all = rag_tokenizer.tokenize(tks_all).split()
|
||||||
fea = [
|
fea = [
|
||||||
up.get("R", -1) == down.get("R", -1),
|
up.get("R", -1) == down.get("R", -1),
|
||||||
@ -144,39 +143,30 @@ class RAGFlowPdfParser:
|
|||||||
down["layout_type"] == "text",
|
down["layout_type"] == "text",
|
||||||
up["layout_type"] == "table",
|
up["layout_type"] == "table",
|
||||||
down["layout_type"] == "table",
|
down["layout_type"] == "table",
|
||||||
True if re.search(
|
True if re.search(r"([。?!;!?;+))]|[a-z]\.)$", up["text"]) else False,
|
||||||
r"([。?!;!?;+))]|[a-z]\.)$",
|
|
||||||
up["text"]) else False,
|
|
||||||
True if re.search(r"[,:‘“、0-9(+-]$", up["text"]) else False,
|
True if re.search(r"[,:‘“、0-9(+-]$", up["text"]) else False,
|
||||||
True if re.search(
|
True if re.search(r"(^.?[/,?;:\],。;:’”?!》】)-])", down["text"]) else False,
|
||||||
r"(^.?[/,?;:\],。;:’”?!》】)-])",
|
|
||||||
down["text"]) else False,
|
|
||||||
True if re.match(r"[\((][^\(\)()]+[)\)]$", up["text"]) else False,
|
True if re.match(r"[\((][^\(\)()]+[)\)]$", up["text"]) else False,
|
||||||
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
||||||
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
||||||
True if re.search(r"[\((][^\))]+$", up["text"])
|
True if re.search(r"[\((][^\))]+$", up["text"]) and re.search(r"[\))]", down["text"]) else False,
|
||||||
and re.search(r"[\))]", down["text"]) else False,
|
|
||||||
self._match_proj(down),
|
self._match_proj(down),
|
||||||
True if re.match(r"[A-Z]", down["text"]) else False,
|
True if re.match(r"[A-Z]", down["text"]) else False,
|
||||||
True if re.match(r"[A-Z]", up["text"][-1]) else False,
|
True if re.match(r"[A-Z]", up["text"][-1]) else False,
|
||||||
True if re.match(r"[a-z0-9]", up["text"][-1]) else False,
|
True if re.match(r"[a-z0-9]", up["text"][-1]) else False,
|
||||||
True if re.match(r"[0-9.%,-]+$", down["text"]) else False,
|
True if re.match(r"[0-9.%,-]+$", down["text"]) else False,
|
||||||
up["text"].strip()[-2:] == down["text"].strip()[-2:] if len(up["text"].strip()
|
up["text"].strip()[-2:] == down["text"].strip()[-2:] if len(up["text"].strip()) > 1 and len(down["text"].strip()) > 1 else False,
|
||||||
) > 1 and len(
|
|
||||||
down["text"].strip()) > 1 else False,
|
|
||||||
up["x0"] > down["x1"],
|
up["x0"] > down["x1"],
|
||||||
abs(self.__height(up) - self.__height(down)) / min(self.__height(up),
|
abs(self.__height(up) - self.__height(down)) / min(self.__height(up), self.__height(down)),
|
||||||
self.__height(down)),
|
|
||||||
self._x_dis(up, down) / max(w, 0.000001),
|
self._x_dis(up, down) / max(w, 0.000001),
|
||||||
(len(up["text"]) - len(down["text"])) /
|
(len(up["text"]) - len(down["text"])) / max(len(up["text"]), len(down["text"])),
|
||||||
max(len(up["text"]), len(down["text"])),
|
|
||||||
len(tks_all) - len(tks_up) - len(tks_down),
|
len(tks_all) - len(tks_up) - len(tks_down),
|
||||||
len(tks_down) - len(tks_up),
|
len(tks_down) - len(tks_up),
|
||||||
tks_down[-1] == tks_up[-1] if tks_down and tks_up else False,
|
tks_down[-1] == tks_up[-1] if tks_down and tks_up else False,
|
||||||
max(down["in_row"], up["in_row"]),
|
max(down["in_row"], up["in_row"]),
|
||||||
abs(down["in_row"] - up["in_row"]),
|
abs(down["in_row"] - up["in_row"]),
|
||||||
len(tks_down) == 1 and rag_tokenizer.tag(tks_down[0]).find("n") >= 0,
|
len(tks_down) == 1 and rag_tokenizer.tag(tks_down[0]).find("n") >= 0,
|
||||||
len(tks_up) == 1 and rag_tokenizer.tag(tks_up[0]).find("n") >= 0
|
len(tks_up) == 1 and rag_tokenizer.tag(tks_up[0]).find("n") >= 0,
|
||||||
]
|
]
|
||||||
return fea
|
return fea
|
||||||
|
|
||||||
@ -187,9 +177,7 @@ class RAGFlowPdfParser:
|
|||||||
for i in range(len(arr) - 1):
|
for i in range(len(arr) - 1):
|
||||||
for j in range(i, -1, -1):
|
for j in range(i, -1, -1):
|
||||||
# restore the order using th
|
# restore the order using th
|
||||||
if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threshold \
|
if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threshold and arr[j + 1]["top"] < arr[j]["top"] and arr[j + 1]["page_number"] == arr[j]["page_number"]:
|
||||||
and arr[j + 1]["top"] < arr[j]["top"] \
|
|
||||||
and arr[j + 1]["page_number"] == arr[j]["page_number"]:
|
|
||||||
tmp = arr[j]
|
tmp = arr[j]
|
||||||
arr[j] = arr[j + 1]
|
arr[j] = arr[j + 1]
|
||||||
arr[j + 1] = tmp
|
arr[j + 1] = tmp
|
||||||
@ -197,8 +185,7 @@ class RAGFlowPdfParser:
|
|||||||
|
|
||||||
def _has_color(self, o):
|
def _has_color(self, o):
|
||||||
if o.get("ncs", "") == "DeviceGray":
|
if o.get("ncs", "") == "DeviceGray":
|
||||||
if o["stroking_color"] and o["stroking_color"][0] == 1 and o["non_stroking_color"] and \
|
if o["stroking_color"] and o["stroking_color"][0] == 1 and o["non_stroking_color"] and o["non_stroking_color"][0] == 1:
|
||||||
o["non_stroking_color"][0] == 1:
|
|
||||||
if re.match(r"[a-zT_\[\]\(\)-]+", o.get("text", "")):
|
if re.match(r"[a-zT_\[\]\(\)-]+", o.get("text", "")):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@ -216,8 +203,7 @@ class RAGFlowPdfParser:
|
|||||||
if not tbls:
|
if not tbls:
|
||||||
continue
|
continue
|
||||||
for tb in tbls: # for table
|
for tb in tbls: # for table
|
||||||
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
|
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, tb["x1"] + MARGIN, tb["bottom"] + MARGIN
|
||||||
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
|
|
||||||
left *= ZM
|
left *= ZM
|
||||||
top *= ZM
|
top *= ZM
|
||||||
right *= ZM
|
right *= ZM
|
||||||
@ -232,14 +218,13 @@ class RAGFlowPdfParser:
|
|||||||
tbcnt = np.cumsum(tbcnt)
|
tbcnt = np.cumsum(tbcnt)
|
||||||
for i in range(len(tbcnt) - 1): # for page
|
for i in range(len(tbcnt) - 1): # for page
|
||||||
pg = []
|
pg = []
|
||||||
for j, tb_items in enumerate(
|
for j, tb_items in enumerate(recos[tbcnt[i] : tbcnt[i + 1]]): # for table
|
||||||
recos[tbcnt[i]: tbcnt[i + 1]]): # for table
|
poss = pos[tbcnt[i] : tbcnt[i + 1]]
|
||||||
poss = pos[tbcnt[i]: tbcnt[i + 1]]
|
|
||||||
for it in tb_items: # for table components
|
for it in tb_items: # for table components
|
||||||
it["x0"] = (it["x0"] + poss[j][0])
|
it["x0"] = it["x0"] + poss[j][0]
|
||||||
it["x1"] = (it["x1"] + poss[j][0])
|
it["x1"] = it["x1"] + poss[j][0]
|
||||||
it["top"] = (it["top"] + poss[j][1])
|
it["top"] = it["top"] + poss[j][1]
|
||||||
it["bottom"] = (it["bottom"] + poss[j][1])
|
it["bottom"] = it["bottom"] + poss[j][1]
|
||||||
for n in ["x0", "x1", "top", "bottom"]:
|
for n in ["x0", "x1", "top", "bottom"]:
|
||||||
it[n] /= ZM
|
it[n] /= ZM
|
||||||
it["top"] += self.page_cum_height[i]
|
it["top"] += self.page_cum_height[i]
|
||||||
@ -250,8 +235,7 @@ class RAGFlowPdfParser:
|
|||||||
self.tb_cpns.extend(pg)
|
self.tb_cpns.extend(pg)
|
||||||
|
|
||||||
def gather(kwd, fzy=10, ption=0.6):
|
def gather(kwd, fzy=10, ption=0.6):
|
||||||
eles = Recognizer.sort_Y_firstly(
|
eles = Recognizer.sort_Y_firstly([r for r in self.tb_cpns if re.match(kwd, r["label"])], fzy)
|
||||||
[r for r in self.tb_cpns if re.match(kwd, r["label"])], fzy)
|
|
||||||
eles = Recognizer.layouts_cleanup(self.boxes, eles, 5, ption)
|
eles = Recognizer.layouts_cleanup(self.boxes, eles, 5, ption)
|
||||||
return Recognizer.sort_Y_firstly(eles, 0)
|
return Recognizer.sort_Y_firstly(eles, 0)
|
||||||
|
|
||||||
@ -259,8 +243,7 @@ class RAGFlowPdfParser:
|
|||||||
headers = gather(r".*header$")
|
headers = gather(r".*header$")
|
||||||
rows = gather(r".* (row|header)")
|
rows = gather(r".* (row|header)")
|
||||||
spans = gather(r".*spanning")
|
spans = gather(r".*spanning")
|
||||||
clmns = sorted([r for r in self.tb_cpns if re.match(
|
clmns = sorted([r for r in self.tb_cpns if re.match(r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"]))
|
||||||
r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"]))
|
|
||||||
clmns = Recognizer.layouts_cleanup(self.boxes, clmns, 5, 0.5)
|
clmns = Recognizer.layouts_cleanup(self.boxes, clmns, 5, 0.5)
|
||||||
for b in self.boxes:
|
for b in self.boxes:
|
||||||
if b.get("layout_type", "") != "table":
|
if b.get("layout_type", "") != "table":
|
||||||
@ -271,8 +254,7 @@ class RAGFlowPdfParser:
|
|||||||
b["R_top"] = rows[ii]["top"]
|
b["R_top"] = rows[ii]["top"]
|
||||||
b["R_bott"] = rows[ii]["bottom"]
|
b["R_bott"] = rows[ii]["bottom"]
|
||||||
|
|
||||||
ii = Recognizer.find_overlapped_with_threshold(
|
ii = Recognizer.find_overlapped_with_threshold(b, headers, thr=0.3)
|
||||||
b, headers, thr=0.3)
|
|
||||||
if ii is not None:
|
if ii is not None:
|
||||||
b["H_top"] = headers[ii]["top"]
|
b["H_top"] = headers[ii]["top"]
|
||||||
b["H_bott"] = headers[ii]["bottom"]
|
b["H_bott"] = headers[ii]["bottom"]
|
||||||
@ -305,12 +287,12 @@ class RAGFlowPdfParser:
|
|||||||
return
|
return
|
||||||
bxs = [(line[0], line[1][0]) for line in bxs]
|
bxs = [(line[0], line[1][0]) for line in bxs]
|
||||||
bxs = Recognizer.sort_Y_firstly(
|
bxs = Recognizer.sort_Y_firstly(
|
||||||
[{"x0": b[0][0] / ZM, "x1": b[1][0] / ZM,
|
[
|
||||||
"top": b[0][1] / ZM, "text": "", "txt": t,
|
{"x0": b[0][0] / ZM, "x1": b[1][0] / ZM, "top": b[0][1] / ZM, "text": "", "txt": t, "bottom": b[-1][1] / ZM, "chars": [], "page_number": pagenum}
|
||||||
"bottom": b[-1][1] / ZM,
|
for b, t in bxs
|
||||||
"chars": [],
|
if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]
|
||||||
"page_number": pagenum} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
|
],
|
||||||
self.mean_height[pagenum-1] / 3
|
self.mean_height[pagenum - 1] / 3,
|
||||||
)
|
)
|
||||||
|
|
||||||
# merge chars in the same rect
|
# merge chars in the same rect
|
||||||
@ -321,7 +303,7 @@ class RAGFlowPdfParser:
|
|||||||
continue
|
continue
|
||||||
ch = c["bottom"] - c["top"]
|
ch = c["bottom"] - c["top"]
|
||||||
bh = bxs[ii]["bottom"] - bxs[ii]["top"]
|
bh = bxs[ii]["bottom"] - bxs[ii]["top"]
|
||||||
if abs(ch - bh) / max(ch, bh) >= 0.7 and c["text"] != ' ':
|
if abs(ch - bh) / max(ch, bh) >= 0.7 and c["text"] != " ":
|
||||||
self.lefted_chars.append(c)
|
self.lefted_chars.append(c)
|
||||||
continue
|
continue
|
||||||
bxs[ii]["chars"].append(c)
|
bxs[ii]["chars"].append(c)
|
||||||
@ -345,8 +327,7 @@ class RAGFlowPdfParser:
|
|||||||
img_np = np.array(img)
|
img_np = np.array(img)
|
||||||
for b in bxs:
|
for b in bxs:
|
||||||
if not b["text"]:
|
if not b["text"]:
|
||||||
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
|
left, right, top, bott = b["x0"] * ZM, b["x1"] * ZM, b["top"] * ZM, b["bottom"] * ZM
|
||||||
ZM, b["top"] * ZM, b["bottom"] * ZM
|
|
||||||
b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
|
b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
|
||||||
boxes_to_reg.append(b)
|
boxes_to_reg.append(b)
|
||||||
del b["txt"]
|
del b["txt"]
|
||||||
@ -356,21 +337,17 @@ class RAGFlowPdfParser:
|
|||||||
del boxes_to_reg[i]["box_image"]
|
del boxes_to_reg[i]["box_image"]
|
||||||
logging.info(f"__ocr recognize {len(bxs)} boxes cost {timer() - start}s")
|
logging.info(f"__ocr recognize {len(bxs)} boxes cost {timer() - start}s")
|
||||||
bxs = [b for b in bxs if b["text"]]
|
bxs = [b for b in bxs if b["text"]]
|
||||||
if self.mean_height[pagenum-1] == 0:
|
if self.mean_height[pagenum - 1] == 0:
|
||||||
self.mean_height[pagenum-1] = np.median([b["bottom"] - b["top"]
|
self.mean_height[pagenum - 1] = np.median([b["bottom"] - b["top"] for b in bxs])
|
||||||
for b in bxs])
|
|
||||||
self.boxes.append(bxs)
|
self.boxes.append(bxs)
|
||||||
|
|
||||||
def _layouts_rec(self, ZM, drop=True):
|
def _layouts_rec(self, ZM, drop=True):
|
||||||
assert len(self.page_images) == len(self.boxes)
|
assert len(self.page_images) == len(self.boxes)
|
||||||
self.boxes, self.page_layout = self.layouter(
|
self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM, drop=drop)
|
||||||
self.page_images, self.boxes, ZM, drop=drop)
|
|
||||||
# cumlative Y
|
# cumlative Y
|
||||||
for i in range(len(self.boxes)):
|
for i in range(len(self.boxes)):
|
||||||
self.boxes[i]["top"] += \
|
self.boxes[i]["top"] += self.page_cum_height[self.boxes[i]["page_number"] - 1]
|
||||||
self.page_cum_height[self.boxes[i]["page_number"] - 1]
|
self.boxes[i]["bottom"] += self.page_cum_height[self.boxes[i]["page_number"] - 1]
|
||||||
self.boxes[i]["bottom"] += \
|
|
||||||
self.page_cum_height[self.boxes[i]["page_number"] - 1]
|
|
||||||
|
|
||||||
def _text_merge(self):
|
def _text_merge(self):
|
||||||
# merge adjusted boxes
|
# merge adjusted boxes
|
||||||
@ -390,12 +367,10 @@ class RAGFlowPdfParser:
        while i < len(bxs) - 1:
            b = bxs[i]
            b_ = bxs[i + 1]
            if b.get("layoutno", "0") != b_.get("layoutno", "1") or b.get("layout_type", "") in ["table", "figure", "equation"]:
                i += 1
                continue
            if abs(self._y_dis(b, b_)) < self.mean_height[bxs[i]["page_number"] - 1] / 3:
                # merge
                bxs[i]["x1"] = b_["x1"]
                bxs[i]["top"] = (b["top"] + b_["top"]) / 2

@ -408,16 +383,14 @@ class RAGFlowPdfParser:
            dis_thr = 1
            dis = b["x1"] - b_["x0"]
            if b.get("layout_type", "") != "text" or b_.get("layout_type", "") != "text":
                if end_with(b, ",") or start_with(b_, "(,"):
                    dis_thr = -8
                else:
                    i += 1
                    continue

            if abs(self._y_dis(b, b_)) < self.mean_height[bxs[i]["page_number"] - 1] / 5 and dis >= dis_thr and b["x1"] < b_["x1"]:
                # merge
                bxs[i]["x1"] = b_["x1"]
                bxs[i]["top"] = (b["top"] + b_["top"]) / 2
@ -429,23 +402,22 @@ class RAGFlowPdfParser:
        self.boxes = bxs

    def _naive_vertical_merge(self, zoomin=3):
        import math
        bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)

        column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
        if not column_width or math.isnan(column_width):
            column_width = self.mean_width[0]
        self.column_num = int(self.page_images[0].size[0] / zoomin / column_width)
        if column_width < self.page_images[0].size[0] / zoomin / self.column_num:
            logging.info("Multi-column................... {} {}".format(column_width, self.page_images[0].size[0] / zoomin / self.column_num))
            self.boxes = self.sort_X_by_page(self.boxes, column_width / self.column_num)

        i = 0
        while i + 1 < len(bxs):
            b = bxs[i]
            b_ = bxs[i + 1]
            if b["page_number"] < b_["page_number"] and re.match(r"[0-9 •一—-]+$", b["text"]):
                bxs.pop(i)
                continue
            if not b["text"].strip():

@ -453,8 +425,7 @@ class RAGFlowPdfParser:
                continue
            concatting_feats = [
                b["text"].strip()[-1] in ",;:'\",、‘“;:-",
                len(b["text"].strip()) > 1 and b["text"].strip()[-2] in ",;:'\",‘“、;:",
                b_["text"].strip() and b_["text"].strip()[0] in "。;?!?”)),,、:",
            ]
            # features for not concating

@ -462,21 +433,20 @@ class RAGFlowPdfParser:
                b.get("layoutno", 0) != b_.get("layoutno", 0),
                b["text"].strip()[-1] in "。?!?",
                self.is_english and b["text"].strip()[-1] in ".!?",
                b["page_number"] == b_["page_number"] and b_["top"] - b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5,
                b["page_number"] < b_["page_number"] and abs(b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4,
            ]
            # split features
            detach_feats = [b["x1"] < b_["x0"], b["x0"] > b_["x1"]]
            if (any(feats) and not any(concatting_feats)) or any(detach_feats):
                logging.debug(
                    "{} {} {} {}".format(
                        b["text"],
                        b_["text"],
                        any(feats),
                        any(concatting_feats),
                    )
                )
                i += 1
                continue
            # merge up and down
@ -529,14 +499,11 @@ class RAGFlowPdfParser:
                if not concat_between_pages and down["page_number"] > up["page_number"]:
                    break

                if up.get("R", "") != down.get("R", "") and up["text"][-1] != ",":
                    i += 1
                    continue

                if re.match(r"[0-9]{2,3}/[0-9]{3}$", up["text"]) or re.match(r"[0-9]{2,3}/[0-9]{3}$", down["text"]) or not down["text"].strip():
                    i += 1
                    continue

@ -544,14 +511,12 @@ class RAGFlowPdfParser:
                    i += 1
                    continue

                if up["x1"] < down["x0"] - 10 * mw or up["x0"] > down["x1"] + 10 * mw:
                    i += 1
                    continue

                if i - dp < 5 and up.get("layout_type") == "text":
                    if up.get("layoutno", "1") == down.get("layoutno", "2"):
                        dfs(down, i + 1)
                        boxes.pop(i)
                        return

@ -559,8 +524,7 @@ class RAGFlowPdfParser:
                    continue

                fea = self._updown_concat_features(up, down)
                if self.updown_cnt_mdl.predict(xgb.DMatrix([fea]))[0] <= 0.5:
                    i += 1
                    continue
                dfs(down, i + 1)

@ -584,16 +548,14 @@ class RAGFlowPdfParser:
                c["text"] = c["text"].strip()
                if not c["text"]:
                    continue
                if t["text"] and re.match(r"[0-9\.a-zA-Z]+$", t["text"][-1] + c["text"][-1]):
                    t["text"] += " "
                t["text"] += c["text"]
                t["x0"] = min(t["x0"], c["x0"])
                t["x1"] = max(t["x1"], c["x1"])
                t["page_number"] = min(t["page_number"], c["page_number"])
                t["bottom"] = c["bottom"]
                if not t["layout_type"] and c["layout_type"]:
                    t["layout_type"] = c["layout_type"]
            boxes.append(t)
@ -605,25 +567,20 @@ class RAGFlowPdfParser:
        findit = False
        i = 0
        while i < len(self.boxes):
            if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", re.sub(r"( | |\u3000)+", "", self.boxes[i]["text"].lower())):
                i += 1
                continue
            findit = True
            eng = re.match(r"[0-9a-zA-Z :'.-]{5,}", self.boxes[i]["text"].strip())
            self.boxes.pop(i)
            if i >= len(self.boxes):
                break
            prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(self.boxes[i]["text"].strip().split()[:2])
            while not prefix:
                self.boxes.pop(i)
                if i >= len(self.boxes):
                    break
                prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(self.boxes[i]["text"].strip().split()[:2])
            self.boxes.pop(i)
            if i >= len(self.boxes) or not prefix:
                break

@ -662,10 +619,12 @@ class RAGFlowPdfParser:
                self.boxes.pop(i + 1)
                continue

            if (
                b["text"].strip()[0] != b_["text"].strip()[0]
                or b["text"].strip()[0].lower() in set("qwertyuopasdfghjklzxcvbnm")
                or rag_tokenizer.is_chinese(b["text"].strip()[0])
                or b["top"] > b_["bottom"]
            ):
                i += 1
                continue
            b_["text"] = b["text"] + "\n" + b_["text"]

@ -685,12 +644,8 @@ class RAGFlowPdfParser:
            if "layoutno" not in self.boxes[i]:
                i += 1
                continue
            lout_no = str(self.boxes[i]["page_number"]) + "-" + str(self.boxes[i]["layoutno"])
            if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title", "figure caption", "reference"]:
                nomerge_lout_no.append(lst_lout_no)
            if self.boxes[i]["layout_type"] == "table":
                if re.match(r"(数据|资料|图表)*来源[:: ]", self.boxes[i]["text"]):
@ -716,8 +671,7 @@ class RAGFlowPdfParser:

        # merge table on different pages
        nomerge_lout_no = set(nomerge_lout_no)
        tbls = sorted([(k, bxs) for k, bxs in tables.items()], key=lambda x: (x[1][0]["top"], x[1][0]["x0"]))

        i = len(tbls) - 1
        while i - 1 >= 0:

@ -758,9 +712,7 @@ class RAGFlowPdfParser:
                if b.get("layout_type", "").find("caption") >= 0:
                    continue
                y_dis = self._y_dis(c, b)
                x_dis = self._x_dis(c, b) if not x_overlapped(c, b) else 0
                dis = y_dis * y_dis + x_dis * x_dis
                if dis < minv:
                    mink = k

@ -774,18 +726,10 @@ class RAGFlowPdfParser:
            # continue
            if tv < fv and tk:
                tables[tk].insert(0, c)
                logging.debug("TABLE:" + self.boxes[i]["text"] + "; Cap: " + tk)
            elif fk:
                figures[fk].insert(0, c)
                logging.debug("FIGURE:" + self.boxes[i]["text"] + "; Cap: " + tk)
            self.boxes.pop(i)

        def cropout(bxs, ltype, poss):

@ -794,29 +738,19 @@ class RAGFlowPdfParser:
            if len(pn) < 2:
                pn = list(pn)[0]
                ht = self.page_cum_height[pn]
                b = {"x0": np.min([b["x0"] for b in bxs]), "top": np.min([b["top"] for b in bxs]) - ht, "x1": np.max([b["x1"] for b in bxs]), "bottom": np.max([b["bottom"] for b in bxs]) - ht}
                louts = [layout for layout in self.page_layout[pn] if layout["type"] == ltype]
                ii = Recognizer.find_overlapped(b, louts, naive=True)
                if ii is not None:
                    b = louts[ii]
                else:
                    logging.warning(f"Missing layout match: {pn + 1},%s" % (bxs[0].get("layoutno", "")))

                left, top, right, bott = b["x0"], b["top"], b["x1"], b["bottom"]
                if right < left:
                    right = left + 1
                poss.append((pn + self.page_from, left, right, top, bott))
                return self.page_images[pn].crop((left * ZM, top * ZM, right * ZM, bott * ZM))
            pn = {}
            for b in bxs:
                p = b["page_number"] - 1
@ -825,10 +759,7 @@ class RAGFlowPdfParser:
                pn[p].append(b)
            pn = sorted(pn.items(), key=lambda x: x[0])
            imgs = [cropout(arr, ltype, poss) for p, arr in pn]
            pic = Image.new("RGB", (int(np.max([i.size[0] for i in imgs])), int(np.sum([m.size[1] for m in imgs]))), (245, 245, 245))
            height = 0
            for img in imgs:
                pic.paste(img, (0, int(height)))

@ -848,30 +779,20 @@ class RAGFlowPdfParser:
            poss = []

            if separate_tables_figures:
                figure_results.append((cropout(bxs, "figure", poss), [txt]))
                figure_positions.append(poss)
            else:
                res.append((cropout(bxs, "figure", poss), [txt]))
                positions.append(poss)

        for k, bxs in tables.items():
            if not bxs:
                continue
            bxs = Recognizer.sort_Y_firstly(bxs, np.mean([(b["bottom"] - b["top"]) / 2 for b in bxs]))

            poss = []

            res.append((cropout(bxs, "table", poss), self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
            positions.append(poss)

        if separate_tables_figures:

@ -905,7 +826,7 @@ class RAGFlowPdfParser:
            (r"[0-9]+)", 10),
            (r"[\((][0-9]+[)\)]", 11),
            (r"[零一二三四五六七八九十百]+是", 12),
            (r"[⚫•➢✓]", 12),
        ]:
            if re.match(p, line):
                return j
@ -924,12 +845,9 @@ class RAGFlowPdfParser:
        if pn[-1] - 1 >= page_images_cnt:
            return ""

        return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format("-".join([str(p) for p in pn]), bx["x0"], bx["x1"], top, bott)

    def __filterout_scraps(self, boxes, ZM):

        def width(b):
            return b["x1"] - b["x0"]

@ -939,8 +857,7 @@ class RAGFlowPdfParser:
        def usefull(b):
            if b.get("layout_type"):
                return True
            if width(b) > self.page_images[b["page_number"] - 1].size[0] / ZM / 3:
                return True
            if b["bottom"] - b["top"] > self.mean_height[b["page_number"] - 1]:
                return True

@ -952,31 +869,23 @@ class RAGFlowPdfParser:
            widths = []
            pw = self.page_images[boxes[0]["page_number"] - 1].size[0] / ZM
            mh = self.mean_height[boxes[0]["page_number"] - 1]
            mj = self.proj_match(boxes[0]["text"]) or boxes[0].get("layout_type", "") == "title"

            def dfs(line, st):
                nonlocal mh, pw, lines, widths
                lines.append(line)
                widths.append(width(line))
                mmj = self.proj_match(line["text"]) or line.get("layout_type", "") == "title"
                for i in range(st + 1, min(st + 20, len(boxes))):
                    if (boxes[i]["page_number"] - line["page_number"]) > 0:
                        break
                    if not mmj and self._y_dis(line, boxes[i]) >= 3 * mh and height(line) < 1.5 * mh:
                        break

                    if not usefull(boxes[i]):
                        continue
                    if mmj or (self._x_dis(boxes[i], line) < pw / 10):
                        # and abs(width(boxes[i])-width_mean)/max(width(boxes[i]),width_mean)<0.5):
                        # concat following
                        dfs(boxes[i], i)
                        boxes.pop(i)

@ -992,11 +901,9 @@ class RAGFlowPdfParser:
            boxes.pop(0)
            mw = np.mean(widths)
            if mj or mw / pw >= 0.35 or mw > 200:
                res.append("\n".join([c["text"] + self._line_tag(c, ZM) for c in lines]))
            else:
                logging.debug("REMOVED: " + "<<".join([c["text"] for c in lines]))

        return "\n\n".join(res)
@ -1004,16 +911,14 @@ class RAGFlowPdfParser:
    def total_page_number(fnm, binary=None):
        try:
            with sys.modules[LOCK_KEY_pdfplumber]:
                pdf = pdfplumber.open(fnm) if not binary else pdfplumber.open(BytesIO(binary))
            total_page = len(pdf.pages)
            pdf.close()
            return total_page
        except Exception:
            logging.exception("total_page_number")
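A minimal usage sketch (illustrative only, not part of this changeset): counting pages up front before deciding how to split parsing work. It assumes total_page_number is exposed as a static method, as its self-less signature suggests, that the class lives at deepdoc.parser.pdf_parser as in the upstream repo, and that "paper.pdf" is a locally available file.

    from deepdoc.parser.pdf_parser import RAGFlowPdfParser  # assumed import path

    n_pages = RAGFlowPdfParser.total_page_number("paper.pdf")  # from a file path
    with open("paper.pdf", "rb") as f:
        n_pages = RAGFlowPdfParser.total_page_number("paper.pdf", binary=f.read())  # from bytes already in memory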
    def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
        self.lefted_chars = []
        self.mean_height = []
        self.mean_width = []

@ -1025,10 +930,9 @@ class RAGFlowPdfParser:
        start = timer()
        try:
            with sys.modules[LOCK_KEY_pdfplumber]:
                with pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm)) as pdf:
                    self.pdf = pdf
                    self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).annotated for i, p in enumerate(self.pdf.pages[page_from:page_to])]

                    try:
                        self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
@ -1044,11 +948,11 @@ class RAGFlowPdfParser:

        self.outlines = []
        try:
            with pdf2_read(fnm if isinstance(fnm, str) else BytesIO(fnm)) as pdf:
                self.pdf = pdf

                outlines = self.pdf.outline

                def dfs(arr, depth):
                    for a in arr:
                        if isinstance(a, dict):

@ -1065,11 +969,11 @@ class RAGFlowPdfParser:
            logging.warning("Miss outlines")

        logging.debug("Images converted.")
        self.is_english = [
            re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i])))))
            for i in range(len(self.page_chars))
        ]
        if sum([1 if e else 0 for e in self.is_english]) > len(self.page_images) / 2:
            self.is_english = True
        else:
            self.is_english = False

@ -1077,10 +981,12 @@ class RAGFlowPdfParser:
        async def __img_ocr(i, id, img, chars, limiter):
            j = 0
            while j + 1 < len(chars):
                if (
                    chars[j]["text"]
                    and chars[j + 1]["text"]
                    and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"])
                    and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"], chars[j]["width"]) / 2
                ):
                    chars[j]["text"] += " "
                j += 1
@ -1096,12 +1002,8 @@ class RAGFlowPdfParser:
        async def __img_ocr_launcher():
            def __ocr_preprocess():
                chars = self.page_chars[i] if not self.is_english else []
                self.mean_height.append(np.median(sorted([c["height"] for c in chars])) if chars else 0)
                self.mean_width.append(np.median(sorted([c["width"] for c in chars])) if chars else 8)
                self.page_cum_height.append(img.size[1] / zoomin)
                return chars

@ -1110,8 +1012,7 @@ class RAGFlowPdfParser:
                for i, img in enumerate(self.page_images):
                    chars = __ocr_preprocess()

                    nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
                    await trio.sleep(0.1)
            else:
                for i, img in enumerate(self.page_images):

@ -1124,11 +1025,9 @@ class RAGFlowPdfParser:

        logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")

        if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
            bxes = [b for bxs in self.boxes for b in bxs]
            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))

        logging.debug("Is it English:", self.is_english)

@ -1144,8 +1043,7 @@ class RAGFlowPdfParser:
        self._text_merge()
        self._concat_downward()
        self._filter_forpages()
        tbls = self._extract_table_figure(need_image, zoomin, return_html, False)
        return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls

    def parse_into_bboxes(self, fnm, callback=None, zoomin=3):
@ -1177,11 +1075,11 @@ class RAGFlowPdfParser:
        def insert_table_figures(tbls_or_figs, layout_type):
            def min_rectangle_distance(rect1, rect2):
                import math

                pn1, left1, right1, top1, bottom1 = rect1
                pn2, left2, right2, top2, bottom2 = rect2
                if right1 >= left2 and right2 >= left1 and bottom1 >= top2 and bottom2 >= top1:
                    return 0 + (pn1 - pn2) * 10000
                if right1 < left2:
                    dx = left2 - right1
                elif right2 < left1:

@ -1194,18 +1092,16 @@ class RAGFlowPdfParser:
                    dy = top1 - bottom2
                else:
                    dy = 0
                return math.sqrt(dx * dx + dy * dy) + (pn1 - pn2) * 10000

            for (img, txt), poss in tbls_or_figs:
                bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)]
                dists = [(min_rectangle_distance((pn, left, right, top, bott), rect), i) for i, rect in bboxes for pn, left, right, top, bott in poss]
                min_i = np.argmin(dists, axis=0)[0]
                min_i, rect = bboxes[dists[min_i][-1]]
                if isinstance(txt, list):
                    txt = "\n".join(txt)
                self.boxes.insert(min_i, {"page_number": rect[0], "x0": rect[1], "x1": rect[2], "top": rect[3], "bottom": rect[4], "layout_type": layout_type, "text": txt, "image": img})

        for b in self.boxes:
            b["position_tag"] = self._line_tag(b, zoomin)
@ -1225,12 +1121,9 @@ class RAGFlowPdfParser:
    def extract_positions(txt):
        poss = []
        for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
            pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
            left, right, top, bottom = float(left), float(right), float(top), float(bottom)
            poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
        return poss
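A minimal sketch (illustrative only, not part of this changeset) of the position-tag round trip: _line_tag earlier in this file emits an "@@pages\tx0\tx1\ttop\tbottom##" marker, and extract_positions recovers the numbers with the same regex used here. The coordinates below are made up.

    import re

    tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format("2-3", 36.0, 520.0, 88.4, 102.9)
    for t in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", "chunk text" + tag):
        pn, left, right, top, bottom = t.strip("#").strip("@").split("\t")
        pages = [int(p) - 1 for p in pn.split("-")]  # zero-based page indices -> [1, 2]
        print(pages, float(left), float(right), float(top), float(bottom))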
    def crop(self, text, ZM=3, need_position=False):

@ -1241,15 +1134,12 @@ class RAGFlowPdfParser:
                return None, None
            return

        max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
        GAP = 6
        pos = poss[0]
        poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
        pos = poss[-1]
        poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120)))

        positions = []
        for ii, (pns, left, right, top, bottom) in enumerate(poss):

@ -1257,28 +1147,14 @@ class RAGFlowPdfParser:
            bottom *= ZM
            for pn in pns[1:]:
                bottom += self.page_images[pn - 1].size[1]
            imgs.append(self.page_images[pns[0]].crop((left * ZM, top * ZM, right * ZM, min(bottom, self.page_images[pns[0]].size[1]))))
            if 0 < ii < len(poss) - 1:
                positions.append((pns[0] + self.page_from, left, right, top, min(bottom, self.page_images[pns[0]].size[1]) / ZM))
            bottom -= self.page_images[pns[0]].size[1]
            for pn in pns[1:]:
                imgs.append(self.page_images[pn].crop((left * ZM, 0, right * ZM, min(bottom, self.page_images[pn].size[1]))))
                if 0 < ii < len(poss) - 1:
                    positions.append((pn + self.page_from, left, right, 0, min(bottom, self.page_images[pn].size[1]) / ZM))
                bottom -= self.page_images[pn].size[1]

        if not imgs:
@ -1290,14 +1166,12 @@ class RAGFlowPdfParser:
            height += img.size[1] + GAP
        height = int(height)
        width = int(np.max([i.size[0] for i in imgs]))
        pic = Image.new("RGB", (width, height), (245, 245, 245))
        height = 0
        for ii, img in enumerate(imgs):
            if ii == 0 or ii + 1 == len(imgs):
                img = img.convert("RGBA")
                overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
                overlay.putalpha(128)
                img = Image.alpha_composite(img, overlay).convert("RGB")
            pic.paste(img, (0, int(height)))

@ -1312,14 +1186,12 @@ class RAGFlowPdfParser:
        pn = bx["page_number"]
        top = bx["top"] - self.page_cum_height[pn - 1]
        bott = bx["bottom"] - self.page_cum_height[pn - 1]
        poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM)))
        while bott * ZM > self.page_images[pn - 1].size[1]:
            bott -= self.page_images[pn - 1].size[1] / ZM
            top = 0
            pn += 1
            poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM)))
        return poss

@ -1328,9 +1200,7 @@ class PlainParser:
        self.outlines = []
        lines = []
        try:
            self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename))
            for page in self.pdf.pages[from_page:to_page]:
                lines.extend([t for t in page.extract_text().split("\n")])

@ -1367,10 +1237,8 @@ class VisionParser(RAGFlowPdfParser):
    def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
        try:
            with sys.modules[LOCK_KEY_pdfplumber]:
                self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
                self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in enumerate(self.pdf.pages[page_from:page_to])]
                self.total_page = len(self.pdf.pages)
        except Exception:
            self.page_images = None

@ -1397,15 +1265,15 @@ class VisionParser(RAGFlowPdfParser):
            text = picture_vision_llm_chunk(
                binary=img_binary,
                vision_model=self.vision_model,
                prompt=vision_llm_describe_prompt(page=pdf_page_num + 1),
                callback=callback,
            )
            if kwargs.get("callback"):
                kwargs["callback"](idx * 1.0 / len(self.page_images), f"Processed: {idx + 1}/{len(self.page_images)}")

            if text:
                width, height = self.page_images[idx].size
                all_docs.append((text, f"{pdf_page_num + 1} 0 {width / zoomin} 0 {height / zoomin}"))
        return all_docs, []
@ -16,24 +16,28 @@
import io
import sys
import threading

import pdfplumber

from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import AscendLayoutRecognizer
from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer


LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
    sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
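A minimal sketch (illustrative only, not part of this changeset) of why the lock is parked in sys.modules: sys.modules is one process-wide dict, so every importer that serializes its pdfplumber calls this way contends on the same Lock object no matter how many times the module is re-imported, which is exactly how __images__ in the parser uses it. "example.pdf" is a made-up path.

    import sys
    import threading

    import pdfplumber

    LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
    if LOCK_KEY_pdfplumber not in sys.modules:
        sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()

    with sys.modules[LOCK_KEY_pdfplumber]:
        with pdfplumber.open("example.pdf") as pdf:
            print(len(pdf.pages))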
def init_in_out(args):
    import os
    import traceback

    from PIL import Image

    from api.utils.file_utils import traversal_files

    images = []
    outputs = []

@ -44,8 +48,7 @@ def init_in_out(args):
        nonlocal outputs, images
        with sys.modules[LOCK_KEY_pdfplumber]:
            pdf = pdfplumber.open(fnm)
            images = [p.to_image(resolution=72 * zoomin).annotated for i, p in enumerate(pdf.pages)]

        for i, page in enumerate(images):
            outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")

@ -57,10 +60,10 @@ def init_in_out(args):
            pdf_pages(fnm)
            return
        try:
            fp = open(fnm, "rb")
            binary = fp.read()
            fp.close()
            images.append(Image.open(io.BytesIO(binary)).convert("RGB"))
            outputs.append(os.path.split(fnm)[-1])
        except Exception:
            traceback.print_exc()

@ -81,6 +84,7 @@ __all__ = [
    "OCR",
    "Recognizer",
    "LayoutRecognizer",
    "AscendLayoutRecognizer",
    "TableStructureRecognizer",
    "init_in_out",
]
@ -14,6 +14,8 @@
# limitations under the License.
#

import logging
import math
import os
import re
from collections import Counter

@ -45,28 +47,22 @@ class LayoutRecognizer(Recognizer):

    def __init__(self, domain):
        try:
            model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
            super().__init__(self.labels, domain, model_dir)
        except Exception:
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False)
            super().__init__(self.labels, domain, model_dir)

        self.garbage_layouts = ["footer", "header", "reference"]
        self.client = None
        if os.environ.get("TENSORRT_DLA_SVR"):
            from deepdoc.vision.dla_cli import DLAClient

            self.client = DLAClient(os.environ["TENSORRT_DLA_SVR"])

    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
        def __is_garbage(b):
            patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
            return any([re.search(p, b["text"]) for p in patt])

        if self.client:

@ -82,18 +78,23 @@ class LayoutRecognizer(Recognizer):
        page_layout = []
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            lts = [
                {
                    "type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor,
                    "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor,
                    "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                }
                for b in lts
                if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts
            ]
            lts = self.sort_Y_firstly(lts, np.mean([lt["bottom"] - lt["top"] for lt in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            def findLayout(ty):
                nonlocal bxs, lts, self
                lts_ = [lt for lt in lts if lt["type"] == ty]

@ -106,21 +107,17 @@ class LayoutRecognizer(Recognizer):
                        bxs.pop(i)
                        continue

                    ii = self.find_overlapped_with_threshold(bxs[i], lts_, thr=0.4)
                    if ii is None:
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue
                    lts_[ii]["visited"] = True
                    keep_feats = [
                        lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
                        lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
                    ]
                    if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
                        if lts_[ii]["type"] not in garbages:
                            garbages[lts_[ii]["type"]] = []
                        garbages[lts_[ii]["type"]].append(bxs[i]["text"])

@ -128,17 +125,14 @@ class LayoutRecognizer(Recognizer):
                        continue

                    bxs[i]["layoutno"] = f"{ty}-{ii}"
                    bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"] != "equation" else "figure"
                    i += 1

            for lt in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
                findLayout(lt)

            # add box to figure layouts which has not text box
            for i, lt in enumerate([lt for lt in lts if lt["type"] in ["figure", "equation"]]):
                if lt.get("visited"):
                    continue
                lt = deepcopy(lt)

@ -206,13 +200,11 @@ class LayoutRecognizer4YOLOv10(LayoutRecognizer):
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
            top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
            left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
            img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))  # add border
            img /= 255.0
            img = img.transpose(2, 0, 1)
            img = img[np.newaxis, :, :, :].astype(np.float32)
            inputs.append({self.input_names[0]: img, "scale_factor": [shape[1] / ww, shape[0] / hh, dw, dh]})

        return inputs

@ -230,8 +222,7 @@ class LayoutRecognizer4YOLOv10(LayoutRecognizer):
        boxes[:, 2] -= inputs["scale_factor"][2]
        boxes[:, 1] -= inputs["scale_factor"][3]
        boxes[:, 3] -= inputs["scale_factor"][3]
        input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
        boxes = np.multiply(boxes, input_shape, dtype=np.float32)

        unique_class_ids = np.unique(class_ids)
@ -243,8 +234,223 @@ class LayoutRecognizer4YOLOv10(LayoutRecognizer):
            class_keep_boxes = nms(class_boxes, class_scores, 0.45)
            indices.extend(class_indices[class_keep_boxes])

        return [{"type": self.label_list[class_ids[i]].lower(), "bbox": [float(t) for t in boxes[i].tolist()], "score": float(scores[i])} for i in indices]


class AscendLayoutRecognizer(Recognizer):
    labels = [
        "title",
        "Text",
        "Reference",
        "Figure",
        "Figure caption",
        "Table",
        "Table caption",
        "Table caption",
        "Equation",
        "Figure caption",
    ]

    def __init__(self, domain):
        from ais_bench.infer.interface import InferSession

        model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
        model_file_path = os.path.join(model_dir, domain + ".om")

        if not os.path.exists(model_file_path):
            raise ValueError(f"Model file not found: {model_file_path}")

        device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
        self.session = InferSession(device_id=device_id, model_path=model_file_path)
        self.input_shape = self.session.get_inputs()[0].shape[2:4]  # H,W
        self.garbage_layouts = ["footer", "header", "reference"]

    def preprocess(self, image_list):
        inputs = []
        H, W = self.input_shape
        for img in image_list:
            h, w = img.shape[:2]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)

            r = min(H / h, W / w)
            new_unpad = (int(round(w * r)), int(round(h * r)))
            dw, dh = (W - new_unpad[0]) / 2.0, (H - new_unpad[1]) / 2.0

            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
            top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
            left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
            img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))

            img /= 255.0
            img = img.transpose(2, 0, 1)[np.newaxis, :, :, :].astype(np.float32)

            inputs.append(
                {
                    "image": img,
                    "scale_factor": [w / new_unpad[0], h / new_unpad[1]],
                    "pad": [dw, dh],
                    "orig_shape": [h, w],
                }
            )
        return inputs

    def postprocess(self, boxes, inputs, thr=0.25):
        arr = np.squeeze(boxes)
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)

        results = []
        if arr.shape[1] == 6:
            # [x1,y1,x2,y2,score,cls]
            m = arr[:, 4] >= thr
            arr = arr[m]
            if arr.size == 0:
                return []
            xyxy = arr[:, :4].astype(np.float32)
            scores = arr[:, 4].astype(np.float32)
            cls_ids = arr[:, 5].astype(np.int32)

            if "pad" in inputs:
                dw, dh = inputs["pad"]
                sx, sy = inputs["scale_factor"]
                xyxy[:, [0, 2]] -= dw
                xyxy[:, [1, 3]] -= dh
                xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
            else:
                # backup
                sx, sy = inputs["scale_factor"]
                xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)

            keep_indices = []
            for c in np.unique(cls_ids):
                idx = np.where(cls_ids == c)[0]
                k = nms(xyxy[idx], scores[idx], 0.45)
                keep_indices.extend(idx[k])

            for i in keep_indices:
                cid = int(cls_ids[i])
                if 0 <= cid < len(self.labels):
                    results.append({"type": self.labels[cid].lower(), "bbox": [float(t) for t in xyxy[i].tolist()], "score": float(scores[i])})
            return results

        raise ValueError(f"Unexpected output shape: {arr.shape}")

    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
        import re
        from collections import Counter

        assert len(image_list) == len(ocr_res)

        images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
        layouts_all_pages = []  # list of list[{"type","score","bbox":[x1,y1,x2,y2]}]

        conf_thr = max(thr, 0.08)

        batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
        for bi in range(batch_loop_cnt):
            s = bi * batch_size
            e = min((bi + 1) * batch_size, len(images))
            batch_images = images[s:e]

            inputs_list = self.preprocess(batch_images)
            logging.debug("preprocess done")

            for ins in inputs_list:
                feeds = [ins["image"]]
                out_list = self.session.infer(feeds=feeds, mode="static")

                for out in out_list:
                    lts = self.postprocess(out, ins, conf_thr)

                    page_lts = []
                    for b in lts:
                        if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts:
                            x0, y0, x1, y1 = b["bbox"]
                            page_lts.append(
                                {
                                    "type": b["type"],
                                    "score": float(b["score"]),
                                    "x0": float(x0) / scale_factor,
                                    "x1": float(x1) / scale_factor,
                                    "top": float(y0) / scale_factor,
                                    "bottom": float(y1) / scale_factor,
                                    "page_number": len(layouts_all_pages),
                                }
                            )
                    layouts_all_pages.append(page_lts)

        def _is_garbage_text(box):
            patt = [r"^•+$", r"^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", r"^http://[^ ]{12,}", r"\(cid *: *[0-9]+ *\)"]
            return any(re.search(p, box.get("text", "")) for p in patt)

        boxes_out = []
        page_layout = []
        garbages = {}

        for pn, lts in enumerate(layouts_all_pages):
            if lts:
                avg_h = np.mean([lt["bottom"] - lt["top"] for lt in lts])
                lts = self.sort_Y_firstly(lts, avg_h / 2 if avg_h > 0 else 0)

            bxs = ocr_res[pn]
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            def _tag_layout(ty):
                nonlocal bxs, lts
                lts_of_ty = [lt for lt in lts if lt["type"] == ty]
                i = 0
                while i < len(bxs):
                    if bxs[i].get("layout_type"):
                        i += 1
                        continue
                    if _is_garbage_text(bxs[i]):
                        bxs.pop(i)
                        continue

                    ii = self.find_overlapped_with_threshold(bxs[i], lts_of_ty, thr=0.4)
                    if ii is None:
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue

                    lts_of_ty[ii]["visited"] = True

                    keep_feats = [
                        lts_of_ty[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].shape[0] * 0.9 / scale_factor,
                        lts_of_ty[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].shape[0] * 0.1 / scale_factor,
|
||||||
|
]
|
||||||
|
if drop and lts_of_ty[ii]["type"] in self.garbage_layouts and not any(keep_feats):
|
||||||
|
garbages.setdefault(lts_of_ty[ii]["type"], []).append(bxs[i].get("text", ""))
|
||||||
|
bxs.pop(i)
|
||||||
|
continue
|
||||||
|
|
||||||
|
bxs[i]["layoutno"] = f"{ty}-{ii}"
|
||||||
|
bxs[i]["layout_type"] = lts_of_ty[ii]["type"] if lts_of_ty[ii]["type"] != "equation" else "figure"
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
for ty in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
|
||||||
|
_tag_layout(ty)
|
||||||
|
|
||||||
|
figs = [lt for lt in lts if lt["type"] in ["figure", "equation"]]
|
||||||
|
for i, lt in enumerate(figs):
|
||||||
|
if lt.get("visited"):
|
||||||
|
continue
|
||||||
|
lt = deepcopy(lt)
|
||||||
|
lt.pop("type", None)
|
||||||
|
lt["text"] = ""
|
||||||
|
lt["layout_type"] = "figure"
|
||||||
|
lt["layoutno"] = f"figure-{i}"
|
||||||
|
bxs.append(lt)
|
||||||
|
|
||||||
|
boxes_out.extend(bxs)
|
||||||
|
|
||||||
|
garbag_set = set()
|
||||||
|
for k, lst in garbages.items():
|
||||||
|
cnt = Counter(lst)
|
||||||
|
for g, c in cnt.items():
|
||||||
|
if c > 1:
|
||||||
|
garbag_set.add(g)
|
||||||
|
|
||||||
|
ocr_res_new = [b for b in boxes_out if b["text"].strip() not in garbag_set]
|
||||||
|
return ocr_res_new, page_layout
|
||||||
|
|||||||
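As a rough illustration of how the new recognizer is meant to be driven, the sketch below feeds it one page image and an empty per-page OCR box list. The domain name `"layout"`, the dummy inputs, and the device ID are assumptions for the example; the class itself requires a matching `.om` model under `rag/res/deepdoc` and the `ais_bench` runtime, so this only runs on an Ascend setup.

```python
import os
import numpy as np

os.environ.setdefault("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", "0")

# One blank 3x-scaled "page" and an empty OCR box list for it; real callers pass
# rendered PDF pages plus the OCR boxes produced earlier in the deepdoc pipeline.
page_images = [np.full((1188, 918, 3), 255, dtype=np.uint8)]
ocr_boxes = [[]]

recognizer = AscendLayoutRecognizer("layout")  # "layout" is an assumed domain / .om file name
boxes, page_layouts = recognizer(page_images, ocr_boxes, scale_factor=3, thr=0.2, drop=True)

for lt in page_layouts[0]:
    print(lt["type"], round(lt["score"], 3))
```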
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import gc
 import logging
 import copy
 import time
@@ -348,6 +348,13 @@ class TextRecognizer:

         return img

+    def close(self):
+        # close session and release manually
+        logging.info('Close text recognizer.')
+        if hasattr(self, "predictor"):
+            del self.predictor
+            gc.collect()
+
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
@@ -395,6 +402,9 @@ class TextRecognizer:

         return rec_res, time.time() - st

+    def __del__(self):
+        self.close()
+

 class TextDetector:
     def __init__(self, model_dir, device_id: int | None = None):
@@ -479,6 +489,12 @@ class TextDetector:
         dt_boxes = np.array(dt_boxes_new)
         return dt_boxes

+    def close(self):
+        logging.info("Close text detector.")
+        if hasattr(self, "predictor"):
+            del self.predictor
+            gc.collect()
+
     def __call__(self, img):
         ori_im = img.copy()
         data = {'image': img}
@@ -508,6 +524,9 @@ class TextDetector:

         return dt_boxes, time.time() - st

+    def __del__(self):
+        self.close()
+

 class OCR:
     def __init__(self, model_dir=None):
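The new `close()`/`__del__` pair follows a simple pattern: drop the heavyweight inference session explicitly and trigger a garbage-collection pass, so memory is reclaimed even when the interpreter would otherwise keep the predictor alive. A generic sketch of the same pattern (the class and attribute names are illustrative, not taken from the diff):

```python
import gc
import logging


class SessionHolder:
    """Illustrative holder that releases an inference session the same way the diff does."""

    def __init__(self, predictor):
        self.predictor = predictor  # e.g. an ONNX Runtime or Paddle predictor

    def close(self):
        logging.info("Close session holder.")
        if hasattr(self, "predictor"):
            del self.predictor
            gc.collect()

    def __del__(self):
        # __del__ only delegates, so callers can also release resources
        # deterministically by calling close() themselves.
        self.close()
```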
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import gc
 import logging
 import os
 import math
@@ -406,6 +406,12 @@ class Recognizer:
                 "score": float(scores[i])
             } for i in indices]

+    def close(self):
+        logging.info("Close recognizer.")
+        if hasattr(self, "ort_sess"):
+            del self.ort_sess
+            gc.collect()
+
     def __call__(self, image_list, thr=0.7, batch_size=16):
         res = []
         images = []
@@ -430,5 +436,7 @@ class Recognizer:

         return res

+    def __del__(self):
+        self.close()
@@ -23,6 +23,7 @@ from huggingface_hub import snapshot_download

 from api.utils.file_utils import get_project_base_directory
 from rag.nlp import rag_tokenizer

 from .recognizer import Recognizer


@@ -38,31 +39,49 @@ class TableStructureRecognizer(Recognizer):

     def __init__(self):
         try:
-            super().__init__(self.labels, "tsr", os.path.join(
-                get_project_base_directory(),
-                "rag/res/deepdoc"))
+            super().__init__(self.labels, "tsr", os.path.join(get_project_base_directory(), "rag/res/deepdoc"))
         except Exception:
-            super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc",
-                                                                   local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
-                                                                   local_dir_use_symlinks=False))
+            super().__init__(
+                self.labels,
+                "tsr",
+                snapshot_download(
+                    repo_id="InfiniFlow/deepdoc",
+                    local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
+                    local_dir_use_symlinks=False,
+                ),
+            )

     def __call__(self, images, thr=0.2):
-        tbls = super().__call__(images, thr)
+        table_structure_recognizer_type = os.getenv("TABLE_STRUCTURE_RECOGNIZER_TYPE", "onnx").lower()
+        if table_structure_recognizer_type not in ["onnx", "ascend"]:
+            raise RuntimeError("Unsupported table structure recognizer type.")
+
+        if table_structure_recognizer_type == "onnx":
+            logging.debug("Using Onnx table structure recognizer")
+            tbls = super().__call__(images, thr)
+        else:  # ascend
+            logging.debug("Using Ascend table structure recognizer")
+            tbls = self._run_ascend_tsr(images, thr)
+
         res = []
         # align left&right for rows, align top&bottom for columns
         for tbl in tbls:
-            lts = [{"label": b["type"],
-                    "score": b["score"],
-                    "x0": b["bbox"][0], "x1": b["bbox"][2],
-                    "top": b["bbox"][1], "bottom": b["bbox"][-1]
-                    } for b in tbl]
+            lts = [
+                {
+                    "label": b["type"],
+                    "score": b["score"],
+                    "x0": b["bbox"][0],
+                    "x1": b["bbox"][2],
+                    "top": b["bbox"][1],
+                    "bottom": b["bbox"][-1],
+                }
+                for b in tbl
+            ]
             if not lts:
                 continue

-            left = [b["x0"] for b in lts if b["label"].find(
-                "row") > 0 or b["label"].find("header") > 0]
-            right = [b["x1"] for b in lts if b["label"].find(
-                "row") > 0 or b["label"].find("header") > 0]
+            left = [b["x0"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
+            right = [b["x1"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
             if not left:
                 continue
             left = np.mean(left) if len(left) > 4 else np.min(left)
@@ -93,11 +112,8 @@ class TableStructureRecognizer(Recognizer):

     @staticmethod
     def is_caption(bx):
-        patt = [
-            r"[图表]+[ 0-9::]{2,}"
-        ]
-        if any([re.match(p, bx["text"].strip()) for p in patt]) \
-                or bx.get("layout_type", "").find("caption") >= 0:
+        patt = [r"[图表]+[ 0-9::]{2,}"]
+        if any([re.match(p, bx["text"].strip()) for p in patt]) or bx.get("layout_type", "").find("caption") >= 0:
             return True
         return False
@@ -115,7 +131,7 @@ class TableStructureRecognizer(Recognizer):
             (r"^[0-9A-Z/\._~-]+$", "Ca"),
             (r"^[A-Z]*[a-z' -]+$", "En"),
             (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
-            (r"^.{1}$", "Sg")
+            (r"^.{1}$", "Sg"),
         ]
         for p, n in patt:
             if re.search(p, b["text"].strip()):
@@ -156,21 +172,19 @@ class TableStructureRecognizer(Recognizer):
         rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
         rowh = np.min(rowh) if rowh else 0
         boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
-        #for b in boxes:print(b)
+        # for b in boxes:print(b)
         boxes[0]["rn"] = 0
         rows = [[boxes[0]]]
         btm = boxes[0]["bottom"]
         for b in boxes[1:]:
             b["rn"] = len(rows) - 1
             lst_r = rows[-1]
-            if lst_r[-1].get("R", "") != b.get("R", "") \
-                    or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
-                        ):  # new row
+            if lst_r[-1].get("R", "") != b.get("R", "") or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")):  # new row
                 btm = b["bottom"]
                 b["rn"] += 1
                 rows.append([b])
                 continue
-            btm = (btm + b["bottom"]) / 2.
+            btm = (btm + b["bottom"]) / 2.0
             rows[-1].append(b)

         colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
@@ -186,14 +200,14 @@ class TableStructureRecognizer(Recognizer):
         for b in boxes[1:]:
             b["cn"] = len(cols) - 1
             lst_c = cols[-1]
-            if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
-                "page_number"]) \
-                    or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")):  # new col
+            if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1]["page_number"]) or (
+                b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")
+            ):  # new col
                 right = b["x1"]
                 b["cn"] += 1
                 cols.append([b])
                 continue
-            right = (right + b["x1"]) / 2.
+            right = (right + b["x1"]) / 2.0
             cols[-1].append(b)

         tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
@@ -214,10 +228,8 @@ class TableStructureRecognizer(Recognizer):
                 if e > 1:
                     j += 1
                     continue
-                f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
-                     [j - 1][0].get("text")) or j == 0
-                ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
-                      [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
+                f = (j > 0 and tbl[ii][j - 1] and tbl[ii][j - 1][0].get("text")) or j == 0
+                ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
                 if f and ff:
                     j += 1
                     continue
@@ -228,13 +240,11 @@ class TableStructureRecognizer(Recognizer):
                 if j > 0 and not f:
                     for i in range(len(tbl)):
                         if tbl[i][j - 1]:
-                            left = min(left, np.min(
-                                [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
+                            left = min(left, np.min([bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
                 if j + 1 < len(tbl[0]) and not ff:
                     for i in range(len(tbl)):
                         if tbl[i][j + 1]:
-                            right = min(right, np.min(
-                                [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
+                            right = min(right, np.min([a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
                 assert left < 100000 or right < 100000
                 if left < right:
                     for jj in range(j, len(tbl[0])):
@@ -260,8 +270,7 @@ class TableStructureRecognizer(Recognizer):
             for i in range(len(tbl)):
                 tbl[i].pop(j)
             cols.pop(j)
-        assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
-            len(cols), len(tbl[0]))
+        assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (len(cols), len(tbl[0]))

         if len(cols) >= 4:
             # remove single in row
@@ -277,10 +286,8 @@ class TableStructureRecognizer(Recognizer):
                 if e > 1:
                     i += 1
                     continue
-                f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
-                     [jj][0].get("text")) or i == 0
-                ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
-                      [jj][0].get("text")) or i + 1 >= len(tbl)
+                f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1][jj][0].get("text")) or i == 0
+                ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl)
                 if f and ff:
                     i += 1
                     continue
@@ -292,13 +299,11 @@ class TableStructureRecognizer(Recognizer):
                 if i > 0 and not f:
                     for j in range(len(tbl[i - 1])):
                         if tbl[i - 1][j]:
-                            up = min(up, np.min(
-                                [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
+                            up = min(up, np.min([bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
                 if i + 1 < len(tbl) and not ff:
                     for j in range(len(tbl[i + 1])):
                         if tbl[i + 1][j]:
-                            down = min(down, np.min(
-                                [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
+                            down = min(down, np.min([a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
                 assert up < 100000 or down < 100000
                 if up < down:
                     for ii in range(i, len(tbl)):
@@ -333,22 +338,15 @@ class TableStructureRecognizer(Recognizer):
                 cnt += 1
                 if max_type == "Nu" and arr[0]["btype"] == "Nu":
                     continue
-                if any([a.get("H") for a in arr]) \
-                        or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
+                if any([a.get("H") for a in arr]) or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                     h += 1
             if h / cnt > 0.5:
                 hdset.add(i)

         if html:
-            return TableStructureRecognizer.__html_table(cap, hdset,
-                                                         TableStructureRecognizer.__cal_spans(boxes, rows,
-                                                                                              cols, tbl, True)
-                                                         )
+            return TableStructureRecognizer.__html_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True))

-        return TableStructureRecognizer.__desc_table(cap, hdset,
-                                                     TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl,
-                                                                                          False),
-                                                     is_english)
+        return TableStructureRecognizer.__desc_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False), is_english)

     @staticmethod
     def __html_table(cap, hdset, tbl):
@@ -367,10 +365,8 @@ class TableStructureRecognizer(Recognizer):
                 continue
             txt = ""
             if arr:
-                h = min(np.min([c["bottom"] - c["top"]
-                        for c in arr]) / 2, 10)
-                txt = " ".join([c["text"]
-                               for c in Recognizer.sort_Y_firstly(arr, h)])
+                h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
+                txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)])
             txts.append(txt)
             sp = ""
             if arr[0].get("colspan"):
@@ -436,15 +432,11 @@ class TableStructureRecognizer(Recognizer):
                 if headers[j][k].find(headers[j - 1][k]) >= 0:
                     continue
                 if len(headers[j][k]) > len(headers[j - 1][k]):
-                    headers[j][k] += (de if headers[j][k]
-                                      else "") + headers[j - 1][k]
+                    headers[j][k] += (de if headers[j][k] else "") + headers[j - 1][k]
                 else:
-                    headers[j][k] = headers[j - 1][k] \
-                        + (de if headers[j - 1][k] else "") \
-                        + headers[j][k]
+                    headers[j][k] = headers[j - 1][k] + (de if headers[j - 1][k] else "") + headers[j][k]

-        logging.debug(
-            f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
+        logging.debug(f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
         row_txt = []
         for i in range(rowno):
             if i in hdr_rowno:
@@ -503,14 +495,10 @@ class TableStructureRecognizer(Recognizer):
     @staticmethod
     def __cal_spans(boxes, rows, cols, tbl, html=True):
         # calculate span
-        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
-                for cln in cols]
-        crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
-                for cln in cols]
-        rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
-                for row in rows]
-        rbtm = [np.mean([c.get("R_btm", c["bottom"])
-                for c in row]) for row in rows]
+        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) for cln in cols]
+        crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) for cln in cols]
+        rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) for row in rows]
+        rbtm = [np.mean([c.get("R_btm", c["bottom"]) for c in row]) for row in rows]
         for b in boxes:
             if "SP" not in b:
                 continue
@@ -585,3 +573,40 @@ class TableStructureRecognizer(Recognizer):
                 tbl[rowspan[0]][colspan[0]] = arr

         return tbl
+
+    def _run_ascend_tsr(self, image_list, thr=0.2, batch_size=16):
+        import math
+
+        from ais_bench.infer.interface import InferSession
+
+        model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
+        model_file_path = os.path.join(model_dir, "tsr.om")
+
+        if not os.path.exists(model_file_path):
+            raise ValueError(f"Model file not found: {model_file_path}")
+
+        device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
+        session = InferSession(device_id=device_id, model_path=model_file_path)
+
+        images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
+        results = []
+
+        conf_thr = max(thr, 0.08)
+
+        batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
+        for bi in range(batch_loop_cnt):
+            s = bi * batch_size
+            e = min((bi + 1) * batch_size, len(images))
+            batch_images = images[s:e]
+
+            inputs_list = self.preprocess(batch_images)
+            for ins in inputs_list:
+                feeds = []
+                if "image" in ins:
+                    feeds.append(ins["image"])
+                else:
+                    feeds.append(ins[self.input_names[0]])
+                output_list = session.infer(feeds=feeds, mode="static")
+                bb = self.postprocess(output_list, ins, conf_thr)
+                results.append(bb)
+        return results
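With this change the table-structure backend is chosen from the `TABLE_STRUCTURE_RECOGNIZER_TYPE` environment variable at call time. Below is a hedged sketch of switching to the Ascend path; it assumes the `tsr.om` model and `ais_bench` are present, and the import path is the usual deepdoc location in the RAGFlow tree (treat it as an assumption):

```python
import os

os.environ["TABLE_STRUCTURE_RECOGNIZER_TYPE"] = "ascend"  # defaults to "onnx" when unset
os.environ.setdefault("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", "0")

# Assumed import path; the class lives in deepdoc's vision package in the RAGFlow tree.
from deepdoc.vision.table_structure_recognizer import TableStructureRecognizer

table_crops = []  # replace with cropped table images produced by layout recognition

tsr = TableStructureRecognizer()
tables = tsr(table_crops, thr=0.2)
```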
@@ -1,6 +1,9 @@
 ragflow:
   host: ${RAGFLOW_HOST:-0.0.0.0}
   http_port: 9380
+admin:
+  host: ${RAGFLOW_HOST:-0.0.0.0}
+  http_port: 9381
 mysql:
   name: '${MYSQL_DBNAME:-rag_flow}'
   user: '${MYSQL_USER:-root}'
@@ -29,7 +32,6 @@ redis:
   db: 1
   password: '${REDIS_PASSWORD:-infini_rag_flow}'
   host: '${REDIS_HOST:-redis}:6379'
-
 # postgres:
 #   name: '${POSTGRES_DBNAME:-rag_flow}'
 #   user: '${POSTGRES_USER:-rag_flow}'
@@ -65,15 +67,26 @@ redis:
 #     secret: 'secret'
 #     tenant_id: 'tenant_id'
 #     container_name: 'container_name'
+# The OSS object storage uses the MySQL configuration above by default. If you need to switch to another object storage service, please uncomment and configure the following parameters.
+# opendal:
+#   scheme: 'mysql'  # Storage type, such as s3, oss, azure, etc.
+#   config:
+#     oss_table: 'opendal_storage'
 # user_default_llm:
-#   factory: 'Tongyi-Qianwen'
-#   api_key: 'sk-xxxxxxxxxxxxx'
-#   base_url: ''
+#   factory: 'BAAI'
+#   api_key: 'backup'
+#   base_url: 'backup_base_url'
 #   default_models:
-#     chat_model: 'qwen-plus'
-#     embedding_model: 'BAAI/bge-large-zh-v1.5@BAAI'
-#     rerank_model: ''
-#     asr_model: ''
+#     chat_model:
+#       name: 'qwen2.5-7b-instruct'
+#       factory: 'xxxx'
+#       api_key: 'xxxx'
+#       base_url: 'https://api.xx.com'
+#     embedding_model:
+#       name: 'bge-m3'
+#     rerank_model: 'bge-reranker-v2'
+#     asr_model:
+#       model: 'whisper-large-v3'  # alias of name
 #     image2text_model: ''
 # oauth:
 #   oauth2:
@@ -109,3 +122,14 @@ redis:
 #   switch: false
 #   component: false
 #   dataset: false
+# smtp:
+#   mail_server: ""
+#   mail_port: 465
+#   mail_use_ssl: true
+#   mail_use_tls: false
+#   mail_username: ""
+#   mail_password: ""
+#   mail_default_sender:
+#     - "RAGFlow"  # display name
+#     - ""  # sender email address
+#   mail_frontend_url: "https://your-frontend.example.com"
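The template keeps using `${VAR:-default}` placeholders, including for the new `admin` block. The snippet below is a small illustration of how such placeholders can be expanded before the YAML is parsed; it is not the loader RAGFlow actually ships, just a sketch of the substitution rule.

```python
import os
import re


def expand_env(text: str) -> str:
    # Replace ${NAME:-default} with the environment value, or the default when NAME is unset.
    return re.sub(
        r"\$\{(\w+):-([^}]*)\}",
        lambda m: os.getenv(m.group(1), m.group(2)),
        text,
    )


print(expand_env("host: ${RAGFLOW_HOST:-0.0.0.0}\nhttp_port: 9381"))
```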
@@ -3,6 +3,6 @@
   "position": 40,
   "link": {
     "type": "generated-index",
-    "description": "Guides and references on accessing RAGFlow's knowledge bases via MCP."
+    "description": "Guides and references on accessing RAGFlow's datasets via MCP."
   }
 }
@@ -14,9 +14,9 @@ A RAGFlow Model Context Protocol (MCP) server is designed as an independent comp
 An MCP server can start up in either self-host mode (default) or host mode:

 - **Self-host mode**:
-  When launching an MCP server in self-host mode, you must provide an API key to authenticate the MCP server with the RAGFlow server. In this mode, the MCP server can access *only* the datasets (knowledge bases) of a specified tenant on the RAGFlow server.
+  When launching an MCP server in self-host mode, you must provide an API key to authenticate the MCP server with the RAGFlow server. In this mode, the MCP server can access *only* the datasets of a specified tenant on the RAGFlow server.
 - **Host mode**:
-  In host mode, each MCP client can access their own knowledge bases on the RAGFlow server. However, each client request must include a valid API key to authenticate the client with the RAGFlow server.
+  In host mode, each MCP client can access their own datasets on the RAGFlow server. However, each client request must include a valid API key to authenticate the client with the RAGFlow server.

 Once a connection is established, an MCP server communicates with its client in MCP HTTP+SSE (Server-Sent Events) mode, unidirectionally pushing responses from the RAGFlow server to its client in real time.
docs/faq.mdx
@@ -498,7 +498,7 @@ To switch your document engine from Elasticsearch to [Infinity](https://github.c

 ### Where are my uploaded files stored in RAGFlow's image?

-All uploaded files are stored in Minio, RAGFlow's object storage solution. For instance, if you upload your file directly to a knowledge base, it is located at `<knowledgebase_id>/filename`.
+All uploaded files are stored in Minio, RAGFlow's object storage solution. For instance, if you upload your file directly to a dataset, it is located at `<knowledgebase_id>/filename`.

 ---

@@ -507,3 +507,16 @@ All uploaded files are stored in Minio, RAGFlow's object storage solution. For i
 You can control the batch size for document parsing and embedding by setting the environment variables `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`. Increasing these values may improve throughput for large-scale data processing, but will also increase memory usage. Adjust them according to your hardware resources.

 ---
+
+### How to accelerate the question-answering speed of my chat assistant?
+
+See [here](./guides/chat/best_practices/accelerate_question_answering.mdx).
+
+---
+
+### How to accelerate the question-answering speed of my Agent?
+
+See [here](./guides/agent/best_practices/accelerate_agent_question_answering.md).
+
+---
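The batch-size FAQ above refers to two plain environment variables. As a minimal sketch, a worker process might read them with conservative fallbacks; the default values shown here are illustrative placeholders, not RAGFlow's own defaults.

```python
import os

# Fallback values are illustrative; set DOC_BULK_SIZE and EMBEDDING_BATCH_SIZE in docker/.env
# to tune parsing and embedding throughput against available memory.
doc_bulk_size = int(os.getenv("DOC_BULK_SIZE", "4"))
embedding_batch_size = int(os.getenv("EMBEDDING_BATCH_SIZE", "16"))

print(doc_bulk_size, embedding_batch_size)
```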
@@ -26,6 +26,84 @@ An **Agent** component is essential when you need the LLM to assist with summari

 2. If your Agent involves dataset retrieval, ensure you [have properly configured your target knowledge base(s)](../../dataset/configure_knowledge_base.md).

+## Quickstart
+
+### 1. Click on an **Agent** component to show its configuration panel
+
+The corresponding configuration panel appears to the right of the canvas. Use this panel to define and fine-tune the **Agent** component's behavior.
+
+### 2. Select your model
+
+Click **Model**, and select a chat model from the dropdown menu.
+
+:::tip NOTE
+If no model appears, check whether you have added a chat model on the **Model providers** page.
+:::
+
+### 3. Update system prompt (Optional)
+
+The system prompt typically defines your model's role. You can either keep the system prompt as is or customize it to override the default.
+
+### 4. Update user prompt
+
+The user prompt typically defines your model's task. You will find the `sys.query` variable auto-populated. Type `/` or click **(x)** to view or add variables.
+
+In this quickstart, we assume your **Agent** component is used standalone (without tools or sub-Agents below), so you may also need to specify retrieved chunks using the `formalized_content` variable:
+
+
+
+### 5. Skip Tools and Agent
+
+The **+ Add tools** and **+ Add agent** sections are used *only* when you need to configure your **Agent** component as a planner (with tools or sub-Agents beneath). In this quickstart, we assume your **Agent** component is used standalone (without tools or sub-Agents beneath).
+
+### 6. Choose the next component
+
+When necessary, click the **+** button on the **Agent** component to choose the next component in the workflow from the dropdown list.
+
+## Connect to an MCP server as a client
+
+:::danger IMPORTANT
+In this section, we assume your **Agent** will be configured as a planner, with a Tavily tool beneath it.
+:::
+
+### 1. Navigate to the MCP configuration page
+
+
+
+### 2. Configure your Tavily MCP server
+
+Update your MCP server's name, URL (including the API key), server type, and other necessary settings. When configured correctly, the available tools will be displayed.
+
+
+
+### 3. Navigate to your Agent's editing page
+
+### 4. Connect to your MCP server
+
+1. Click **+ Add tools**:
+
+   
+
+2. Click **MCP** to show the available MCP servers.
+
+3. Select your MCP server:
+
+   *The target MCP server appears below your Agent component, and your Agent will autonomously decide when to invoke the available tools it offers.*
+
+   
+
+### 5. Update system prompt to specify trigger conditions (Optional)
+
+To ensure reliable tool calls, you may specify within the system prompt which tasks should trigger each tool call.
+
+### 6. View the available tools of your MCP server
+
+On the canvas, click the newly-populated Tavily server to view and select its available tools:
+
+
+
 ## Configurations

 ### Model
@@ -69,7 +147,7 @@ An **Agent** component relies on keys (variables) to specify its data inputs. It

 #### Advanced usage

-From v0.20.5 onwards, four framework-level prompt blocks are available in the **System prompt** field. Type `/` or click **(x)** to view them; they appear under the **Framework** entry in the dropdown menu.
+From v0.20.5 onwards, four framework-level prompt blocks are available in the **System prompt** field, enabling you to customize and *override* prompts at the framework level. Type `/` or click **(x)** to view them; they appear under the **Framework** entry in the dropdown menu.

 - `task_analysis` prompt block
   - This block is responsible for analyzing tasks — either a user task or a task assigned by the lead Agent when the **Agent** component is acting as a Sub-Agent.
@@ -100,6 +178,12 @@ From v0.20.5 onwards, four framework-level prompt blocks are available in the **
 - `citation_guidelines` prompt block
   - Reference design: [citation_prompt.md](https://github.com/infiniflow/ragflow/blob/main/rag/prompts/citation_prompt.md)

+*The screenshots below show the framework prompt blocks available to an **Agent** component, both as a standalone and as a planner (with a Tavily tool below):*
+
+
+
+
+
 ### User prompt

 The user-defined prompt. Defaults to `sys.query`, the user query. As a general rule, when using the **Agent** component as a standalone module (not as a planner), you usually need to specify the corresponding **Retrieval** component’s output variable (`formalized_content`) here as part of the input to the LLM.
@@ -129,7 +213,7 @@ Defines the maximum number of attempts the agent will make to retry a failed tas

 The waiting period in seconds that the agent observes before retrying a failed task, helping to prevent immediate repeated attempts and allowing system conditions to improve. Defaults to 1 second.

-### Max rounds
+### Max reflection rounds

 Defines the maximum number of reflection rounds of the selected chat model. Defaults to 1 round.
@@ -145,18 +229,4 @@ The global variable name for the output of the **Agent** component, which can be

 ### Why does it take so long for my Agent to respond?

-An Agent’s response time generally depends on two key factors: the LLM’s capabilities and the prompt, the latter reflecting task complexity. When using an Agent, you should always balance task demands with the LLM’s ability. See [How to balance task complexity with an Agent's performance and speed?](#how-to-balance-task-complexity-with-an-agents-performance-and-speed) for details.
-
-## Best practices
-
-### How to balance task complexity with an Agent’s performance and speed?
-
-- For simple tasks, such as retrieval, rewriting, formatting, or structured data extraction, use concise prompts, remove planning or reasoning instructions, enforce output length limits, and select smaller or Turbo-class models. This significantly reduces latency and cost with minimal impact on quality.
-
-- For complex tasks, like multi-step reasoning, cross-document synthesis, or tool-based workflows, maintain or enhance prompts that include planning, reflection, and verification steps.
-
-- In multi-Agent orchestration systems, delegate simple subtasks to sub-Agents using smaller, faster models, and reserve more powerful models for the lead Agent to handle complexity and uncertainty.
-
-:::tip KEY INSIGHT
-Focus on minimizing output tokens — through summarization, bullet points, or explicit length limits — as this has far greater impact on reducing latency than optimizing input size.
-:::
+See [here](../best_practices/accelerate_agent_question_answering.md) for details.
@@ -67,14 +67,14 @@ You can tune document parsing and embedding efficiency by setting the environmen

 ## Frequently asked questions

-### Is the uploaded file in a knowledge base?
+### Is the uploaded file in a dataset?

-No. Files uploaded to an agent as input are not stored in a knowledge base and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods.
+No. Files uploaded to an agent as input are not stored in a dataset and hence will not be processed using RAGFlow's built-in OCR, DLR or TSR models, or chunked using RAGFlow's built-in chunking methods.

 ### File size limit for an uploaded file

 There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.

 :::tip NOTE
-The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a knowledge base or **File Management**. These settings DO NOT apply in this scenario.
+The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario.
 :::
|
|||||||
|
|
||||||
This field allows you to enter and edit your source code.
|
This field allows you to enter and edit your source code.
|
||||||
|
|
||||||
|
:::danger IMPORTANT
|
||||||
|
If your code implementation includes defined variables, whether input or output variables, ensure they are also specified in the corresponding **Input** or **Output** sections.
|
||||||
|
:::
|
||||||
|
|
||||||
#### A Python code example
|
#### A Python code example
|
||||||
|
|
||||||
```Python
|
```Python
|
||||||
@ -77,6 +81,15 @@ This field allows you to enter and edit your source code.
|
|||||||
|
|
||||||
You define the output variable(s) of the **Code** component here.
|
You define the output variable(s) of the **Code** component here.
|
||||||
|
|
||||||
|
:::danger IMPORTANT
|
||||||
|
If you define output variables here, ensure they are also defined in your code implementation; otherwise, their values will be `null`. The following are two examples:
|
||||||
|
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
:::
|
||||||
|
|
||||||
### Output
|
### Output
|
||||||
|
|
||||||
The defined output variable(s) will be auto-populated here.
|
The defined output variable(s) will be auto-populated here.
|
||||||
|
|||||||
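To make the new IMPORTANT notes concrete: whatever is listed under **Output** must also be produced by the code itself. The sketch below follows the style of the doc's existing Python example, returning a dictionary whose key matches an output variable named `result`; the function name, signature, and variable names are illustrative assumptions, not mandated by the component.

```python
def main(arg1: str, arg2: str) -> dict:
    # The "result" key must match an output variable declared in the Output section;
    # otherwise that variable's value will be null at runtime.
    return {
        "result": arg1 + arg2,
    }
```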
docs/guides/agent/agent_component_reference/execute_sql.md (new file, 79 lines)
@@ -0,0 +1,79 @@
---
sidebar_position: 25
slug: /execute_sql
---

# Execute SQL tool

A tool that executes SQL queries on a specified relational database.

---

The **Execute SQL** tool enables you to connect to a relational database and run SQL queries, whether entered directly or generated by the system’s Text2SQL capability via an **Agent** component.

## Prerequisites

- A database instance properly configured and running.
- The database must be one of the following types:
  - MySQL
  - PostgreSQL
  - MariaDB
  - Microsoft SQL Server

## Examples

You can pair an **Agent** component with the **Execute SQL** tool, with the **Agent** generating SQL statements and the **Execute SQL** tool handling database connection and query execution. An example of this setup can be found in the **SQL Assistant** Agent template shown below:



## Configurations

### SQL statement

This text input field allows you to write static SQL queries, such as `SELECT * FROM my_table`, and dynamic SQL queries using variables.

:::tip NOTE
Click **(x)** or type `/` to insert variables.
:::

For dynamic SQL queries, you can include variables in your SQL queries, such as `SELECT * FROM /sys.query`; if an **Agent** component is paired with the **Execute SQL** tool to generate SQL tasks (see the [Examples](#examples) section), you can directly insert that **Agent**'s output, `content`, into this field.

### Database type

The supported database type. Currently the following database types are available:

- MySQL
- PostgreSQL
- MariaDB
- Microsoft SQL Server (MSSQL)

### Database

The name of the database to connect to.

### Username

The username with access privileges to the database.

### Host

The IP address of the database server.

### Port

The port number on which the database server is listening.

### Password

The password for the database user.

### Max records

The maximum number of records returned by the SQL query to control response size and improve efficiency. Defaults to `1024`.

### Output

The **Execute SQL** tool provides two output variables:

- `formalized_content`: A string. If you reference this variable in a **Message** component, the returned records are displayed as a table.
- `json`: An object array. If you reference this variable in a **Message** component, the returned records will be presented as key-value pairs.
@@ -9,7 +9,7 @@ A component that retrieves information from specified datasets.

 ## Scenarios

-A **Retrieval** component is essential in most RAG scenarios, where information is extracted from designated knowledge bases before being sent to the LLM for content generation. A **Retrieval** component can operate either as a standalone workflow module or as a tool for an **Agent** component. In the latter role, the **Agent** component has autonomous control over when to invoke it for query and retrieval.
+A **Retrieval** component is essential in most RAG scenarios, where information is extracted from designated datasets before being sent to the LLM for content generation. A **Retrieval** component can operate either as a standalone workflow module or as a tool for an **Agent** component. In the latter role, the **Agent** component has autonomous control over when to invoke it for query and retrieval.

 The following screenshot shows a reference design using the **Retrieval** component, where the component serves as a tool for an **Agent** component. You can find it from the **Report Agent Using Knowledge Base** Agent template.

@@ -17,7 +17,7 @@ The following screenshot shows a reference design using the **Retrieval** compon

 ## Prerequisites

-Ensure you [have properly configured your target knowledge base(s)](../../dataset/configure_knowledge_base.md).
+Ensure you [have properly configured your target dataset(s)](../../dataset/configure_knowledge_base.md).

 ## Quickstart

@@ -36,9 +36,9 @@ The **Retrieval** component depends on query variables to specify its queries.

 By default, you can use `sys.query`, which is the user query and the default output of the **Begin** component. All global variables defined before the **Retrieval** component can also be used as query statements. Use the `(x)` button or type `/` to show all the available query variables.

-### 3. Select knowledge base(s) to query
+### 3. Select dataset(s) to query

-You can specify one or multiple knowledge bases to retrieve data from. If selecting mutiple, ensure they use the same embedding model.
+You can specify one or multiple datasets to retrieve data from. If selecting multiple, ensure they use the same embedding model.

 ### 4. Expand **Advanced Settings** to configure the retrieval method

@@ -52,7 +52,7 @@ Using a rerank model will *significantly* increase the system's response time. I

 ### 5. Enable cross-language search

-If your user query is different from the languages of the knowledge bases, you can select the target languages in the **Cross-language search** dropdown menu. The model will then translates queries to ensure accurate matching of semantic meaning across languages.
+If your user query is different from the languages of the datasets, you can select the target languages in the **Cross-language search** dropdown menu. The model will then translate queries to ensure accurate matching of semantic meaning across languages.

 ### 6. Test retrieval results
@@ -76,10 +76,10 @@ The **Retrieval** component relies on query variables to specify its queries. Al

 ### Knowledge bases

-Select the knowledge base(s) to retrieve data from.
+Select the dataset(s) to retrieve data from.

-- If no knowledge base is selected, meaning conversations with the agent will not be based on any knowledge base, ensure that the **Empty response** field is left blank to avoid an error.
-- If you select multiple knowledge bases, you must ensure that the knowledge bases (datasets) you select use the same embedding model; otherwise, an error message would occur.
+- If no dataset is selected, meaning conversations with the agent will not be based on any dataset, ensure that the **Empty response** field is left blank to avoid an error.
+- If you select multiple datasets, you must ensure that the datasets you select use the same embedding model; otherwise, an error message would occur.

 ### Similarity threshold

@@ -110,11 +110,11 @@ Using a rerank model will *significantly* increase the system's response time.

 ### Empty response

-- Set this as a response if no results are retrieved from the knowledge base(s) for your query, or
+- Set this as a response if no results are retrieved from the dataset(s) for your query, or
 - Leave this field blank to allow the chat model to improvise when nothing is found.

 :::caution WARNING
-If you do not specify a knowledge base, you must leave this field blank; otherwise, an error would occur.
+If you do not specify a dataset, you must leave this field blank; otherwise, an error would occur.
 :::

 ### Cross-language search
@@ -124,10 +124,10 @@ Select one or more languages for cross‑language search. If no language is sele
 ### Use knowledge graph

 :::caution IMPORTANT
-Before enabling this feature, ensure you have properly [constructed a knowledge graph from each target knowledge base](../../dataset/construct_knowledge_graph.md).
+Before enabling this feature, ensure you have properly [constructed a knowledge graph from each target dataset](../../dataset/construct_knowledge_graph.md).
 :::

-Whether to use knowledge graph(s) in the specified knowledge base(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.
+Whether to use knowledge graph(s) in the specified dataset(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.

 ### Output
@@ -27,7 +27,7 @@ Agents and RAG are complementary techniques, each enhancing the other’s capabi
 Before proceeding, ensure that:

 1. You have properly set the LLM to use. See the guides on [Configure your API key](../models/llm_api_key_setup.md) or [Deploy a local LLM](../models/deploy_local_llm.mdx) for more information.
-2. You have a knowledge base configured and the corresponding files properly parsed. See the guide on [Configure a knowledge base](../dataset/configure_knowledge_base.md) for more information.
+2. You have a dataset configured and the corresponding files properly parsed. See the guide on [Configure a dataset](../dataset/configure_knowledge_base.md) for more information.

 :::
docs/guides/agent/best_practices/_category_.json (new file, 8 lines)
@@ -0,0 +1,8 @@
{
  "label": "Best practices",
  "position": 30,
  "link": {
    "type": "generated-index",
    "description": "Best practices on Agent configuration."
  }
}
@ -0,0 +1,58 @@
+---
+sidebar_position: 1
+slug: /accelerate_agent_question_answering
+---
+
+# Accelerate answering
+
+A checklist to speed up question answering.
+
+---
+
+Please note that some of your settings may consume a significant amount of time. If you often find that your question answering is time-consuming, here is a checklist to consider:
+
+## Balance task complexity with an Agent’s performance and speed
+
+An Agent’s response time generally depends on many factors, e.g., the LLM’s capabilities and the prompt, the latter reflecting task complexity. When using an Agent, you should always balance task demands with the LLM’s ability.
+
+- For simple tasks, such as retrieval, rewriting, formatting, or structured data extraction, use concise prompts, remove planning or reasoning instructions, enforce output length limits, and select smaller or Turbo-class models. This significantly reduces latency and cost with minimal impact on quality.
+
+- For complex tasks, like multi-step reasoning, cross-document synthesis, or tool-based workflows, maintain or enhance prompts that include planning, reflection, and verification steps.
+
+- In multi-Agent orchestration systems, delegate simple subtasks to sub-Agents using smaller, faster models, and reserve more powerful models for the lead Agent to handle complexity and uncertainty.
+
+:::tip KEY INSIGHT
+Focus on minimizing output tokens — through summarization, bullet points, or explicit length limits — as this has far greater impact on reducing latency than optimizing input size.
+:::
+
+## Disable Reasoning
+
+Disabling the **Reasoning** toggle will reduce the LLM's thinking time. For a model like Qwen3, you also need to add `/no_think` to the system prompt to disable reasoning.
+
+## Disable Rerank model
+
+- Leaving the **Rerank model** field empty (in the corresponding **Retrieval** component) will significantly decrease retrieval time.
+- When using a rerank model, ensure you have a GPU for acceleration; otherwise, the reranking process will be *prohibitively* slow.
+
+:::tip NOTE
+Please note that rerank models are essential in certain scenarios. There is always a trade-off between speed and performance; you must weigh the pros against cons for your specific case.
+:::
+
+## Check the time taken for each task
+
+Click the light bulb icon above the *current* dialogue and scroll down the popup window to view the time taken for each task:
+
+| Item name         | Description                                                                                    |
+| ----------------- | ---------------------------------------------------------------------------------------------- |
+| Total             | Total time spent on this conversation round, including chunk retrieval and answer generation.   |
+| Check LLM         | Time to validate the specified LLM.                                                              |
+| Create retriever  | Time to create a chunk retriever.                                                                |
+| Bind embedding    | Time to initialize an embedding model instance.                                                  |
+| Bind LLM          | Time to initialize an LLM instance.                                                              |
+| Tune question     | Time to optimize the user query using the context of the multi-turn conversation.                |
+| Bind reranker     | Time to initialize a reranker model instance for chunk retrieval.                                |
+| Generate keywords | Time to extract keywords from the user query.                                                    |
+| Retrieval         | Time to retrieve the chunks.                                                                     |
+| Generate answer   | Time to generate the answer.                                                                     |
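
As a concrete illustration of the `/no_think` tip in the **Disable Reasoning** section above: for a Qwen3-class model, the switch goes into the system prompt itself, for example at the very beginning. The prompt wording below is an arbitrary sketch, not a required template:

```
/no_think
You are a concise assistant. Answer using only the retrieved chunks and keep each answer under 100 words.
```
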
@ -22,7 +22,7 @@ When debugging your chat assistant, you can use AI search as a reference to veri
 ## Prerequisites

 - Ensure that you have configured the system's default models on the **Model providers** page.
-- Ensure that the intended knowledge bases are properly configured and the intended documents have finished file parsing.
+- Ensure that the intended datasets are properly configured and the intended documents have finished file parsing.

 ## Frequently asked questions

@ -6,21 +6,22 @@ slug: /accelerate_question_answering
 # Accelerate answering
 import APITable from '@site/src/components/APITable';

-A checklist to speed up question answering.
+A checklist to speed up question answering for your chat assistant.

 ---

 Please note that some of your settings may consume a significant amount of time. If you often find that your question answering is time-consuming, here is a checklist to consider:

-- In the **Prompt engine** tab of your **Chat Configuration** dialogue, disabling **Multi-turn optimization** will reduce the time required to get an answer from the LLM.
+- Disabling **Multi-turn optimization** will reduce the time required to get an answer from the LLM.
-- In the **Prompt engine** tab of your **Chat Configuration** dialogue, leaving the **Rerank model** field empty will significantly decrease retrieval time.
+- Leaving the **Rerank model** field empty will significantly decrease retrieval time.
+- Disabling the **Reasoning** toggle will reduce the LLM's thinking time. For a model like Qwen3, you also need to add `/no_think` to the system prompt to disable reasoning.
 - When using a rerank model, ensure you have a GPU for acceleration; otherwise, the reranking process will be *prohibitively* slow.

 :::tip NOTE
 Please note that rerank models are essential in certain scenarios. There is always a trade-off between speed and performance; you must weigh the pros against cons for your specific case.
 :::

-- In the **Assistant settings** tab of your **Chat Configuration** dialogue, disabling **Keyword analysis** will reduce the time to receive an answer from the LLM.
+- Disabling **Keyword analysis** will reduce the time to receive an answer from the LLM.
 - When chatting with your chat assistant, click the light bulb icon above the *current* dialogue and scroll down the popup window to view the time taken for each task:

 

@ -25,13 +25,13 @@ In the **Variable** section, you add, remove, or update variables.

 ### `{knowledge}` - a reserved variable

-`{knowledge}` is the system's reserved variable, representing the chunks retrieved from the knowledge base(s) specified by **Knowledge bases** under the **Assistant settings** tab. If your chat assistant is associated with certain knowledge bases, you can keep it as is.
+`{knowledge}` is the system's reserved variable, representing the chunks retrieved from the dataset(s) specified by **Knowledge bases** under the **Assistant settings** tab. If your chat assistant is associated with certain datasets, you can keep it as is.

 :::info NOTE
 It currently makes no difference whether `{knowledge}` is set as optional or mandatory, but please note this design will be updated in due course.
 :::

-From v0.17.0 onward, you can start an AI chat without specifying knowledge bases. In this case, we recommend removing the `{knowledge}` variable to prevent unnecessary reference and keeping the **Empty response** field empty to avoid errors.
+From v0.17.0 onward, you can start an AI chat without specifying datasets. In this case, we recommend removing the `{knowledge}` variable to prevent unnecessary reference and keeping the **Empty response** field empty to avoid errors.

 ### Custom variables

@ -45,15 +45,15 @@ Besides `{knowledge}`, you can also define your own variables to pair with the s
 After you add or remove variables in the **Variable** section, ensure your changes are reflected in the system prompt to avoid inconsistencies or errors. Here's an example:

 ```
-You are an intelligent assistant. Please answer the question by summarizing chunks from the specified knowledge base(s)...
+You are an intelligent assistant. Please answer the question by summarizing chunks from the specified dataset(s)...

 Your answers should follow a professional and {style} style.

 ...

-Here is the knowledge base:
+Here is the dataset:
 {knowledge}
-The above is the knowledge base.
+The above is the dataset.
 ```

 :::tip NOTE
@ -9,7 +9,7 @@ Initiate an AI-powered chat with a configured chat assistant.

 ---

-Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular knowledge base or multiple knowledge bases. Once you have created your knowledge base, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
+Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.

 ## Start an AI chat

@ -21,12 +21,12 @@ You start an AI conversation by creating an assistant.

 2. Update **Assistant settings**:

-- **Assistant name** is the name of your chat assistant. Each assistant corresponds to a dialogue with a unique combination of knowledge bases, prompts, hybrid search configurations, and large model settings.
+- **Assistant name** is the name of your chat assistant. Each assistant corresponds to a dialogue with a unique combination of datasets, prompts, hybrid search configurations, and large model settings.
 - **Empty response**:
-- If you wish to *confine* RAGFlow's answers to your knowledge bases, leave a response here. Then, when it doesn't retrieve an answer, it *uniformly* responds with what you set here.
+- If you wish to *confine* RAGFlow's answers to your datasets, leave a response here. Then, when it doesn't retrieve an answer, it *uniformly* responds with what you set here.
-- If you wish RAGFlow to *improvise* when it doesn't retrieve an answer from your knowledge bases, leave it blank, which may give rise to hallucinations.
+- If you wish RAGFlow to *improvise* when it doesn't retrieve an answer from your datasets, leave it blank, which may give rise to hallucinations.
 - **Show quote**: This is a key feature of RAGFlow and enabled by default. RAGFlow does not work like a black box. Instead, it clearly shows the sources of information that its responses are based on.
-- Select the corresponding knowledge bases. You can select one or multiple knowledge bases, but ensure that they use the same embedding model, otherwise an error would occur.
+- Select the corresponding datasets. You can select one or multiple datasets, but ensure that they use the same embedding model, otherwise an error would occur.

 3. Update **Prompt engine**:

@ -37,14 +37,14 @@ You start an AI conversation by creating an assistant.
 - If **Rerank model** is selected, the hybrid score system uses keyword similarity and reranker score, and the default weight assigned to the reranker score is 1-0.7=0.3.
 - **Top N** determines the *maximum* number of chunks to feed to the LLM. In other words, even if more chunks are retrieved, only the top N chunks are provided as input.
 - **Multi-turn optimization** enhances user queries using existing context in a multi-round conversation. It is enabled by default. When enabled, it will consume additional LLM tokens and significantly increase the time to generate answers.
-- **Use knowledge graph** indicates whether to use knowledge graph(s) in the specified knowledge base(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.
+- **Use knowledge graph** indicates whether to use knowledge graph(s) in the specified dataset(s) during retrieval for multi-hop question answering. When enabled, this would involve iterative searches across entity, relationship, and community report chunks, greatly increasing retrieval time.
 - **Reasoning** indicates whether to generate answers through reasoning processes like Deepseek-R1/OpenAI o1. Once enabled, the chat model autonomously integrates Deep Research during question answering when encountering an unknown topic. This involves the chat model dynamically searching external knowledge and generating final answers through reasoning.
 - **Rerank model** sets the reranker model to use. It is left empty by default.
 - If **Rerank model** is left empty, the hybrid score system uses keyword similarity and vector similarity, and the default weight assigned to the vector similarity component is 1-0.7=0.3.
 - If **Rerank model** is selected, the hybrid score system uses keyword similarity and reranker score, and the default weight assigned to the reranker score is 1-0.7=0.3.
 - [Cross-language search](../../references/glossary.mdx#cross-language-search): Optional
 Select one or more target languages from the dropdown menu. The system’s default chat model will then translate your query into the selected target language(s). This translation ensures accurate semantic matching across languages, allowing you to retrieve relevant results regardless of language differences.
-- When selecting target languages, please ensure that these languages are present in the knowledge base to guarantee an effective search.
+- When selecting target languages, please ensure that these languages are present in the dataset to guarantee an effective search.
 - If no target language is selected, the system will search only in the language of your query, which may cause relevant information in other languages to be missed.
 - **Variable** refers to the variables (keys) to be used in the system prompt. `{knowledge}` is a reserved variable. Click **Add** to add more variables for the system prompt.
 - If you are uncertain about the logic behind **Variable**, leave it *as-is*.
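
To make the 0.7 / 1-0.7 weighting described above concrete, here is a toy Python sketch of a weighted hybrid score. It illustrates the arithmetic only and is not RAGFlow's actual retrieval code; the function name and the assumption that scores are normalized to [0, 1] are ours:

```python
def hybrid_score(keyword_sim: float, other_sim: float, keyword_weight: float = 0.7) -> float:
    """Toy weighted sum: keyword similarity plus either vector similarity
    (no rerank model selected) or the reranker score (rerank model selected).
    Assumes both inputs are already normalized to the range [0, 1]."""
    return keyword_weight * keyword_sim + (1.0 - keyword_weight) * other_sim


# Example: keyword similarity 0.8 and vector similarity (or reranker score) 0.5
# give 0.7 * 0.8 + 0.3 * 0.5 = 0.71.
print(hybrid_score(0.8, 0.5))
```
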
@ -106,7 +106,7 @@ RAGFlow offers HTTP and Python APIs for you to integrate RAGFlow's capabilities

 You can use iframe to embed the created chat assistant into a third-party webpage:

-1. Before proceeding, you must [acquire an API key](../models/llm_api_key_setup.md); otherwise, an error message would appear.
+1. Before proceeding, you must [acquire an API key](../../develop/acquire_ragflow_api_key.md); otherwise, an error message would appear.
 2. Hover over an intended chat assistant **>** **Edit** to show the **iframe** window:

 

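
Besides the iframe embed, the API key acquired in step 1 can also be used to call the chat assistant over HTTP. The snippet below is a minimal sketch using the Python `requests` library; the host, chat assistant ID, and key are placeholders, and the endpoint path and payload fields should be verified against the HTTP API reference for your RAGFlow version:

```python
import requests

# Placeholders -- replace with your own RAGFlow host, chat assistant ID, and API key.
BASE_URL = "http://localhost:9380"
CHAT_ID = "<chat-assistant-id>"
API_KEY = "<ragflow-api-key>"

# Ask the chat assistant one question (non-streaming); endpoint per the HTTP API
# reference for chat completions -- confirm it matches your installed version.
response = requests.post(
    f"{BASE_URL}/api/v1/chats/{CHAT_ID}/completions",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"question": "Summarize the key points of the uploaded report.", "stream": False},
    timeout=60,
)
response.raise_for_status()
print(response.json())
```
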
@ -3,6 +3,6 @@
   "position": 0,
   "link": {
     "type": "generated-index",
-    "description": "Guides on configuring a knowledge base."
+    "description": "Guides on configuring a dataset."
   }
 }
Some files were not shown because too many files have changed in this diff.