mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-23 03:26:53 +08:00
Compare commits
171 Commits
ca3bd2cf9f
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c9b6e032b | |||
| 3beb85efa0 | |||
| bc7b864a6c | |||
| 93091f4551 | |||
| 2d9e7b4acd | |||
| 6f3f69b62e | |||
| bfd5435087 | |||
| 0e9fe68110 | |||
| 89f438fe45 | |||
| 2e2c8f6ca9 | |||
| 6cd4fd91e6 | |||
| 83e17d8c4a | |||
| e1143d40bc | |||
| f98abf14a8 | |||
| 2a87778e10 | |||
| 5836823187 | |||
| 5a7026cf55 | |||
| bc7935d627 | |||
| 7787085664 | |||
| 960ecd3158 | |||
| aee9860970 | |||
| 9ebbc5a74d | |||
| 1c65f64bda | |||
| 32841549c1 | |||
| 046d4ffdef | |||
| 4c4d434bc1 | |||
| 80612bc992 | |||
| 927db0b373 | |||
| 120648ac81 | |||
| f367189703 | |||
| 1b1554c563 | |||
| 59f3da2bdf | |||
| b40d639fdb | |||
| 05da2a5872 | |||
| 4fbaa4aae9 | |||
| 3188cd2659 | |||
| c4a982e9fa | |||
| b27dc26be3 | |||
| ab1836f216 | |||
| 7a53d2dd97 | |||
| f3d347f55f | |||
| 9da48ab0bd | |||
| 4a7e40630b | |||
| d6897b6054 | |||
| 828ae1e82f | |||
| 57d189b483 | |||
| 0a8eb11c3d | |||
| 38f0a92da9 | |||
| 067ddcbf23 | |||
| 46305ef35e | |||
| bd9163904a | |||
| b6d7733058 | |||
| 4f036a881d | |||
| 59075a0b58 | |||
| 30bd25716b | |||
| 99dae3c64c | |||
| 045314a1aa | |||
| 2b20d0b3bb | |||
| 59f4c51222 | |||
| 8c1fbfb130 | |||
| cec06bfb5d | |||
| 2167e3a3c0 | |||
| 2ea8dddef6 | |||
| 18867daba7 | |||
| d68176326d | |||
| d531bd4f1a | |||
| ac936005e6 | |||
| d8192f8f17 | |||
| eb35e2b89f | |||
| 97b983fd0b | |||
| b40a7b2e7d | |||
| 9a10558f80 | |||
| f82628c40c | |||
| 7af98328f5 | |||
| 678a4f959c | |||
| 15a8bb2e9c | |||
| b091ff2730 | |||
| 5b22f94502 | |||
| a7671583b3 | |||
| d32fa02d97 | |||
| f72a35188d | |||
| ea619dba3b | |||
| 36b0835740 | |||
| 0795616b34 | |||
| 941651a16f | |||
| 360114ed42 | |||
| ffedb2c6d3 | |||
| 947e63ca14 | |||
| 34d74d9928 | |||
| accae95126 | |||
| 68e5c86e9c | |||
| 64c75d558e | |||
| 41c84fd78f | |||
| d76912ab15 | |||
| 4fe3c24198 | |||
| 44bada64c9 | |||
| 867ec94258 | |||
| fd0a1fde6b | |||
| 653001b14f | |||
| d4f8c724ed | |||
| a7dd3b7e9e | |||
| 638c510468 | |||
| ff11e3171e | |||
| 030d6ba004 | |||
| b226e06e2d | |||
| 2e09db02f3 | |||
| 6abf55c048 | |||
| f9d4179bf2 | |||
| 64b1e0b4c3 | |||
| b65daeb945 | |||
| fbe55cef05 | |||
| 0878526ba8 | |||
| a2db3e3292 | |||
| f522391d1e | |||
| 9562762af2 | |||
| 455fd04050 | |||
| 14c250e3d7 | |||
| a093e616cf | |||
| 696397ebba | |||
| 6f1a555d5f | |||
| 1996aa0dac | |||
| f4e2783eb4 | |||
| 2fd4a3134d | |||
| f1dc2df23c | |||
| de27c006d8 | |||
| 23a9544b73 | |||
| 011bbe9556 | |||
| a442c9cac6 | |||
| 671e719d75 | |||
| 07845be5bd | |||
| 8d406bd2e6 | |||
| 2a4627d9a0 | |||
| 6814ace1aa | |||
| ca9645f39b | |||
| 8e03843145 | |||
| 51ece37db2 | |||
| 45fb2719cf | |||
| bdd9f3d4d1 | |||
| 1f60863f60 | |||
| 02e6870755 | |||
| aa08920e51 | |||
| 7818644129 | |||
| 55c9fc0017 | |||
| 140dd2c8cc | |||
| fada223249 | |||
| 00f8a80ca4 | |||
| 4e9407b4ae | |||
| 42461bc378 | |||
| 92780c486a | |||
| 81f9296d79 | |||
| 606f4e6c9e | |||
| 4cd4526492 | |||
| cc8a10376a | |||
| 5ebe334a2f | |||
| 932496a8ec | |||
| a8a060676a | |||
| 2c10ccd622 | |||
| a2211c200d | |||
| 21ba9e6d72 | |||
| ac9113b0ef | |||
| 11779697de | |||
| d6e006f086 | |||
| d39fa75d36 | |||
| f56bceb2a9 | |||
| bbaf918d74 | |||
| 89a97be2c5 | |||
| 6f2fc2f1cb | |||
| 42da080d89 | |||
| 1f4a17863f | |||
| 4d3a3a97ef | |||
| ff1020ccfb |
28
.github/workflows/tests.yml
vendored
28
.github/workflows/tests.yml
vendored
@ -86,6 +86,9 @@ jobs:
|
||||
mkdir -p ${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}
|
||||
echo "${PR_SHA} ${GITHUB_RUN_ID}" > ${PR_SHA_FP}
|
||||
fi
|
||||
ARTIFACTS_DIR=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/${GITHUB_RUN_ID}
|
||||
echo "ARTIFACTS_DIR=${ARTIFACTS_DIR}" >> ${GITHUB_ENV}
|
||||
rm -rf ${ARTIFACTS_DIR} && mkdir -p ${ARTIFACTS_DIR}
|
||||
|
||||
# https://github.com/astral-sh/ruff-action
|
||||
- name: Static check with Ruff
|
||||
@ -161,7 +164,7 @@ jobs:
|
||||
INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10))
|
||||
INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10))
|
||||
INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10))
|
||||
MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
|
||||
EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
|
||||
MINIO_PORT=$((9000 + RUNNER_NUM * 10))
|
||||
MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10))
|
||||
REDIS_PORT=$((6379 + RUNNER_NUM * 10))
|
||||
@ -181,7 +184,7 @@ jobs:
|
||||
echo -e "INFINITY_THRIFT_PORT=${INFINITY_THRIFT_PORT}" >> docker/.env
|
||||
echo -e "INFINITY_HTTP_PORT=${INFINITY_HTTP_PORT}" >> docker/.env
|
||||
echo -e "INFINITY_PSQL_PORT=${INFINITY_PSQL_PORT}" >> docker/.env
|
||||
echo -e "MYSQL_PORT=${MYSQL_PORT}" >> docker/.env
|
||||
echo -e "EXPOSE_MYSQL_PORT=${EXPOSE_MYSQL_PORT}" >> docker/.env
|
||||
echo -e "MINIO_PORT=${MINIO_PORT}" >> docker/.env
|
||||
echo -e "MINIO_CONSOLE_PORT=${MINIO_CONSOLE_PORT}" >> docker/.env
|
||||
echo -e "REDIS_PORT=${REDIS_PORT}" >> docker/.env
|
||||
@ -211,14 +214,14 @@ jobs:
|
||||
done
|
||||
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
|
||||
|
||||
- name: Run frontend api tests against Elasticsearch
|
||||
- name: Run web api tests against Elasticsearch
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
|
||||
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/ 2>&1 | tee es_web_api_test.log
|
||||
|
||||
- name: Run http api tests against Elasticsearch
|
||||
run: |
|
||||
@ -229,6 +232,13 @@ jobs:
|
||||
done
|
||||
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
|
||||
|
||||
- name: Collect ragflow log
|
||||
if: ${{ !cancelled() }}
|
||||
run: |
|
||||
cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-es
|
||||
echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log
|
||||
sudo rm -rf docker/ragflow-logs
|
||||
|
||||
- name: Stop ragflow:nightly
|
||||
if: always() # always run this step even if previous steps failed
|
||||
run: |
|
||||
@ -249,14 +259,14 @@ jobs:
|
||||
done
|
||||
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
|
||||
|
||||
- name: Run frontend api tests against Infinity
|
||||
- name: Run web api tests against Infinity
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
|
||||
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/ 2>&1 | tee infinity_web_api_test.log
|
||||
|
||||
- name: Run http api tests against Infinity
|
||||
run: |
|
||||
@ -267,6 +277,12 @@ jobs:
|
||||
done
|
||||
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
|
||||
|
||||
- name: Collect ragflow log
|
||||
if: ${{ !cancelled() }}
|
||||
run: |
|
||||
cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-infinity
|
||||
echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log
|
||||
sudo rm -rf docker/ragflow-logs
|
||||
- name: Stop ragflow:nightly
|
||||
if: always() # always run this step even if previous steps failed
|
||||
run: |
|
||||
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@ -44,6 +44,7 @@ cl100k_base.tiktoken
|
||||
chrome*
|
||||
huggingface.co/
|
||||
nltk_data/
|
||||
uv-x86_64*.tar.gz
|
||||
|
||||
# Exclude hash-like temporary files like 9b5ad71b2ce5302211f9c61530b329a4922fc6a4
|
||||
*[0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f]*
|
||||
@ -51,6 +52,13 @@ nltk_data/
|
||||
.venv
|
||||
docker/data
|
||||
|
||||
# OceanBase data and conf
|
||||
docker/oceanbase/conf
|
||||
docker/oceanbase/data
|
||||
|
||||
# SeekDB data and conf
|
||||
docker/seekdb
|
||||
|
||||
|
||||
#--------------------------------------------------#
|
||||
# The following was generated with gitignore.nvim: #
|
||||
@ -197,4 +205,9 @@ ragflow_cli.egg-info
|
||||
backup
|
||||
|
||||
|
||||
.hypothesis
|
||||
.hypothesis
|
||||
|
||||
|
||||
# Added by cargo
|
||||
|
||||
/target
|
||||
|
||||
19
Dockerfile
19
Dockerfile
@ -53,7 +53,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y ghostscript && \
|
||||
apt install -y pandoc && \
|
||||
apt install -y texlive && \
|
||||
apt install -y fonts-freefont-ttf fonts-noto-cjk
|
||||
apt install -y fonts-freefont-ttf fonts-noto-cjk && \
|
||||
apt install -y postgresql-client
|
||||
|
||||
# Install uv
|
||||
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
|
||||
@ -64,10 +65,12 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps
|
||||
echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
|
||||
echo 'default = true' >> /etc/uv/uv.toml; \
|
||||
fi; \
|
||||
tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \
|
||||
&& cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \
|
||||
&& rm -rf uv-x86_64-unknown-linux-gnu \
|
||||
&& uv python install 3.11
|
||||
arch="$(uname -m)"; \
|
||||
if [ "$arch" = "x86_64" ]; then uv_arch="x86_64"; else uv_arch="aarch64"; fi; \
|
||||
tar xzf "/deps/uv-${uv_arch}-unknown-linux-gnu.tar.gz" \
|
||||
&& cp "uv-${uv_arch}-unknown-linux-gnu/"* /usr/local/bin/ \
|
||||
&& rm -rf "uv-${uv_arch}-unknown-linux-gnu" \
|
||||
&& uv python install 3.12
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
|
||||
ENV PATH=/root/.local/bin:$PATH
|
||||
@ -152,11 +155,14 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
|
||||
else \
|
||||
sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
|
||||
fi; \
|
||||
uv sync --python 3.12 --frozen
|
||||
uv sync --python 3.12 --frozen && \
|
||||
# Ensure pip is available in the venv for runtime package installation (fixes #12651)
|
||||
.venv/bin/python3 -m ensurepip --upgrade
|
||||
|
||||
COPY web web
|
||||
COPY docs docs
|
||||
RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked \
|
||||
export NODE_OPTIONS="--max-old-space-size=4096" && \
|
||||
cd web && npm install && npm run build
|
||||
|
||||
COPY .git /ragflow/.git
|
||||
@ -187,7 +193,6 @@ COPY deepdoc deepdoc
|
||||
COPY rag rag
|
||||
COPY agent agent
|
||||
COPY graphrag graphrag
|
||||
COPY agentic_reasoning agentic_reasoning
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY mcp mcp
|
||||
COPY plugin plugin
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
FROM scratch
|
||||
|
||||
# Copy resources downloaded via download_deps.py
|
||||
COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.2.3.jar tika-server-standard-3.2.3.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz /
|
||||
COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.2.3.jar tika-server-standard-3.2.3.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz uv-aarch64-unknown-linux-gnu.tar.gz /
|
||||
|
||||
COPY nltk_data /nltk_data
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ cp pyproject.toml release/$PROJECT_NAME/pyproject.toml
|
||||
cp README.md release/$PROJECT_NAME/README.md
|
||||
|
||||
mkdir release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR -p
|
||||
cp admin_client.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/admin_client.py
|
||||
cp ragflow_cli.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/ragflow_cli.py
|
||||
|
||||
if [ -d "release/$PROJECT_NAME/$SOURCE_DIR" ]; then
|
||||
echo "✅ source dir: release/$PROJECT_NAME/$SOURCE_DIR"
|
||||
|
||||
@ -1,938 +0,0 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import getpass
|
||||
from cmd import Cmd
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from lark import Lark, Transformer, Tree
|
||||
|
||||
GRAMMAR = r"""
|
||||
start: command
|
||||
|
||||
command: sql_command | meta_command
|
||||
|
||||
sql_command: list_services
|
||||
| show_service
|
||||
| startup_service
|
||||
| shutdown_service
|
||||
| restart_service
|
||||
| list_users
|
||||
| show_user
|
||||
| drop_user
|
||||
| alter_user
|
||||
| create_user
|
||||
| activate_user
|
||||
| list_datasets
|
||||
| list_agents
|
||||
| create_role
|
||||
| drop_role
|
||||
| alter_role
|
||||
| list_roles
|
||||
| show_role
|
||||
| grant_permission
|
||||
| revoke_permission
|
||||
| alter_user_role
|
||||
| show_user_permission
|
||||
| show_version
|
||||
|
||||
// meta command definition
|
||||
meta_command: "\\" meta_command_name [meta_args]
|
||||
|
||||
meta_command_name: /[a-zA-Z?]+/
|
||||
meta_args: (meta_arg)+
|
||||
|
||||
meta_arg: /[^\\s"']+/ | quoted_string
|
||||
|
||||
// command definition
|
||||
|
||||
LIST: "LIST"i
|
||||
SERVICES: "SERVICES"i
|
||||
SHOW: "SHOW"i
|
||||
CREATE: "CREATE"i
|
||||
SERVICE: "SERVICE"i
|
||||
SHUTDOWN: "SHUTDOWN"i
|
||||
STARTUP: "STARTUP"i
|
||||
RESTART: "RESTART"i
|
||||
USERS: "USERS"i
|
||||
DROP: "DROP"i
|
||||
USER: "USER"i
|
||||
ALTER: "ALTER"i
|
||||
ACTIVE: "ACTIVE"i
|
||||
PASSWORD: "PASSWORD"i
|
||||
DATASETS: "DATASETS"i
|
||||
OF: "OF"i
|
||||
AGENTS: "AGENTS"i
|
||||
ROLE: "ROLE"i
|
||||
ROLES: "ROLES"i
|
||||
DESCRIPTION: "DESCRIPTION"i
|
||||
GRANT: "GRANT"i
|
||||
REVOKE: "REVOKE"i
|
||||
ALL: "ALL"i
|
||||
PERMISSION: "PERMISSION"i
|
||||
TO: "TO"i
|
||||
FROM: "FROM"i
|
||||
FOR: "FOR"i
|
||||
RESOURCES: "RESOURCES"i
|
||||
ON: "ON"i
|
||||
SET: "SET"i
|
||||
VERSION: "VERSION"i
|
||||
|
||||
list_services: LIST SERVICES ";"
|
||||
show_service: SHOW SERVICE NUMBER ";"
|
||||
startup_service: STARTUP SERVICE NUMBER ";"
|
||||
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
|
||||
restart_service: RESTART SERVICE NUMBER ";"
|
||||
|
||||
list_users: LIST USERS ";"
|
||||
drop_user: DROP USER quoted_string ";"
|
||||
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||
show_user: SHOW USER quoted_string ";"
|
||||
create_user: CREATE USER quoted_string quoted_string ";"
|
||||
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||
|
||||
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||
list_agents: LIST AGENTS OF quoted_string ";"
|
||||
|
||||
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
|
||||
drop_role: DROP ROLE identifier ";"
|
||||
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
|
||||
list_roles: LIST ROLES ";"
|
||||
show_role: SHOW ROLE identifier ";"
|
||||
|
||||
grant_permission: GRANT action_list ON identifier TO ROLE identifier ";"
|
||||
revoke_permission: REVOKE action_list ON identifier FROM ROLE identifier ";"
|
||||
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
|
||||
show_user_permission: SHOW USER PERMISSION quoted_string ";"
|
||||
|
||||
show_version: SHOW VERSION ";"
|
||||
|
||||
action_list: identifier ("," identifier)*
|
||||
|
||||
identifier: WORD
|
||||
quoted_string: QUOTED_STRING
|
||||
status: WORD
|
||||
|
||||
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||
NUMBER: /[0-9]+/
|
||||
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
|
||||
class AdminTransformer(Transformer):
|
||||
def start(self, items):
|
||||
return items[0]
|
||||
|
||||
def command(self, items):
|
||||
return items[0]
|
||||
|
||||
def list_services(self, items):
|
||||
result = {"type": "list_services"}
|
||||
return result
|
||||
|
||||
def show_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "show_service", "number": service_id}
|
||||
|
||||
def startup_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "startup_service", "number": service_id}
|
||||
|
||||
def shutdown_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "shutdown_service", "number": service_id}
|
||||
|
||||
def restart_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "restart_service", "number": service_id}
|
||||
|
||||
def list_users(self, items):
|
||||
return {"type": "list_users"}
|
||||
|
||||
def show_user(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "show_user", "user_name": user_name}
|
||||
|
||||
def drop_user(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "drop_user", "user_name": user_name}
|
||||
|
||||
def alter_user(self, items):
|
||||
user_name = items[3]
|
||||
new_password = items[4]
|
||||
return {"type": "alter_user", "user_name": user_name, "password": new_password}
|
||||
|
||||
def create_user(self, items):
|
||||
user_name = items[2]
|
||||
password = items[3]
|
||||
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
|
||||
|
||||
def activate_user(self, items):
|
||||
user_name = items[3]
|
||||
activate_status = items[4]
|
||||
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
|
||||
|
||||
def list_datasets(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_datasets", "user_name": user_name}
|
||||
|
||||
def list_agents(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_agents", "user_name": user_name}
|
||||
|
||||
def create_role(self, items):
|
||||
role_name = items[2]
|
||||
if len(items) > 4:
|
||||
description = items[4]
|
||||
return {"type": "create_role", "role_name": role_name, "description": description}
|
||||
else:
|
||||
return {"type": "create_role", "role_name": role_name}
|
||||
|
||||
def drop_role(self, items):
|
||||
role_name = items[2]
|
||||
return {"type": "drop_role", "role_name": role_name}
|
||||
|
||||
def alter_role(self, items):
|
||||
role_name = items[2]
|
||||
description = items[5]
|
||||
return {"type": "alter_role", "role_name": role_name, "description": description}
|
||||
|
||||
def list_roles(self, items):
|
||||
return {"type": "list_roles"}
|
||||
|
||||
def show_role(self, items):
|
||||
role_name = items[2]
|
||||
return {"type": "show_role", "role_name": role_name}
|
||||
|
||||
def grant_permission(self, items):
|
||||
action_list = items[1]
|
||||
resource = items[3]
|
||||
role_name = items[6]
|
||||
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
|
||||
|
||||
def revoke_permission(self, items):
|
||||
action_list = items[1]
|
||||
resource = items[3]
|
||||
role_name = items[6]
|
||||
return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
|
||||
|
||||
def alter_user_role(self, items):
|
||||
user_name = items[2]
|
||||
role_name = items[5]
|
||||
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
|
||||
|
||||
def show_user_permission(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "show_user_permission", "user_name": user_name}
|
||||
|
||||
def show_version(self, items):
|
||||
return {"type": "show_version"}
|
||||
|
||||
def action_list(self, items):
|
||||
return items
|
||||
|
||||
def meta_command(self, items):
|
||||
command_name = str(items[0]).lower()
|
||||
args = items[1:] if len(items) > 1 else []
|
||||
|
||||
# handle quoted parameter
|
||||
parsed_args = []
|
||||
for arg in args:
|
||||
if hasattr(arg, "value"):
|
||||
parsed_args.append(arg.value)
|
||||
else:
|
||||
parsed_args.append(str(arg))
|
||||
|
||||
return {"type": "meta", "command": command_name, "args": parsed_args}
|
||||
|
||||
def meta_command_name(self, items):
|
||||
return items[0]
|
||||
|
||||
def meta_args(self, items):
|
||||
return items
|
||||
|
||||
|
||||
def encrypt(input_string):
|
||||
pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
|
||||
pub_key = RSA.importKey(pub)
|
||||
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
|
||||
return base64.b64encode(cipher_text).decode("utf-8")
|
||||
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode("utf-8"))
|
||||
return base64_encoded.decode("utf-8")
|
||||
|
||||
|
||||
class AdminCLI(Cmd):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=AdminTransformer())
|
||||
self.command_history = []
|
||||
self.is_interactive = False
|
||||
self.admin_account = "admin@ragflow.io"
|
||||
self.admin_password: str = "admin"
|
||||
self.session = requests.Session()
|
||||
self.access_token: str = ""
|
||||
self.host: str = ""
|
||||
self.port: int = 0
|
||||
|
||||
intro = r"""Type "\h" for help."""
|
||||
prompt = "admin> "
|
||||
|
||||
def onecmd(self, command: str) -> bool:
|
||||
try:
|
||||
result = self.parse_command(command)
|
||||
|
||||
if isinstance(result, dict):
|
||||
if "type" in result and result.get("type") == "empty":
|
||||
return False
|
||||
|
||||
self.execute_command(result)
|
||||
|
||||
if isinstance(result, Tree):
|
||||
return False
|
||||
|
||||
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
|
||||
return True
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nUse '\\q' to quit")
|
||||
except EOFError:
|
||||
print("\nGoodbye!")
|
||||
return True
|
||||
return False
|
||||
|
||||
def emptyline(self) -> bool:
|
||||
return False
|
||||
|
||||
def default(self, line: str) -> bool:
|
||||
return self.onecmd(line)
|
||||
|
||||
def parse_command(self, command_str: str) -> dict[str, str]:
|
||||
if not command_str.strip():
|
||||
return {"type": "empty"}
|
||||
|
||||
self.command_history.append(command_str)
|
||||
|
||||
try:
|
||||
result = self.parser.parse(command_str)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"type": "error", "message": f"Parse error: {str(e)}"}
|
||||
|
||||
def verify_admin(self, arguments: dict, single_command: bool):
|
||||
self.host = arguments["host"]
|
||||
self.port = arguments["port"]
|
||||
print("Attempt to access server for admin login")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/login"
|
||||
|
||||
attempt_count = 3
|
||||
if single_command:
|
||||
attempt_count = 1
|
||||
|
||||
try_count = 0
|
||||
while True:
|
||||
try_count += 1
|
||||
if try_count > attempt_count:
|
||||
return False
|
||||
|
||||
if single_command:
|
||||
admin_passwd = arguments["password"]
|
||||
else:
|
||||
admin_passwd = getpass.getpass(f"password for {self.admin_account}: ").strip()
|
||||
try:
|
||||
self.admin_password = encrypt(admin_passwd)
|
||||
response = self.session.post(url, json={"email": self.admin_account, "password": self.admin_password})
|
||||
if response.status_code == 200:
|
||||
res_json = response.json()
|
||||
error_code = res_json.get("code", -1)
|
||||
if error_code == 0:
|
||||
self.session.headers.update({"Content-Type": "application/json", "Authorization": response.headers["Authorization"], "User-Agent": "RAGFlow-CLI/0.23.1"})
|
||||
print("Authentication successful.")
|
||||
return True
|
||||
else:
|
||||
error_message = res_json.get("message", "Unknown error")
|
||||
print(f"Authentication failed: {error_message}, try again")
|
||||
continue
|
||||
else:
|
||||
print(f"Bad response,status: {response.status_code}, password is wrong")
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print("Can't access server for admin login (connection failed)")
|
||||
|
||||
def _format_service_detail_table(self, data):
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
if not all([isinstance(v, list) for v in data.values()]):
|
||||
# normal table
|
||||
return data
|
||||
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
|
||||
task_executor_list = []
|
||||
for k, v in data.items():
|
||||
# display latest status
|
||||
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
|
||||
task_executor_list.append(
|
||||
{
|
||||
"task_executor_name": k,
|
||||
**heartbeats[0],
|
||||
}
|
||||
if heartbeats
|
||||
else {"task_executor_name": k}
|
||||
)
|
||||
return task_executor_list
|
||||
|
||||
def _print_table_simple(self, data):
|
||||
if not data:
|
||||
print("No data to print")
|
||||
return
|
||||
if isinstance(data, dict):
|
||||
# handle single row data
|
||||
data = [data]
|
||||
|
||||
columns = list(set().union(*(d.keys() for d in data)))
|
||||
columns.sort()
|
||||
col_widths = {}
|
||||
|
||||
def get_string_width(text):
|
||||
half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
|
||||
width = 0
|
||||
for char in text:
|
||||
if char in half_width_chars:
|
||||
width += 1
|
||||
else:
|
||||
width += 2
|
||||
return width
|
||||
|
||||
for col in columns:
|
||||
max_width = get_string_width(str(col))
|
||||
for item in data:
|
||||
value_len = get_string_width(str(item.get(col, "")))
|
||||
if value_len > max_width:
|
||||
max_width = value_len
|
||||
col_widths[col] = max(2, max_width)
|
||||
|
||||
# Generate delimiter
|
||||
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
|
||||
|
||||
# Print header
|
||||
print(separator)
|
||||
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
|
||||
print(header)
|
||||
print(separator)
|
||||
|
||||
# Print data
|
||||
for item in data:
|
||||
row = "|"
|
||||
for col in columns:
|
||||
value = str(item.get(col, ""))
|
||||
if get_string_width(value) > col_widths[col]:
|
||||
value = value[: col_widths[col] - 3] + "..."
|
||||
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
|
||||
print(row)
|
||||
|
||||
print(separator)
|
||||
|
||||
def run_interactive(self):
|
||||
self.is_interactive = True
|
||||
print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
|
||||
|
||||
while True:
|
||||
try:
|
||||
command = input("admin> ").strip()
|
||||
if not command:
|
||||
continue
|
||||
|
||||
print(f"command: {command}")
|
||||
result = self.parse_command(command)
|
||||
self.execute_command(result)
|
||||
|
||||
if isinstance(result, Tree):
|
||||
continue
|
||||
|
||||
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nUse '\\q' to quit")
|
||||
except EOFError:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
def run_single_command(self, command: str):
|
||||
result = self.parse_command(command)
|
||||
self.execute_command(result)
|
||||
|
||||
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
|
||||
parser = argparse.ArgumentParser(description="Admin CLI Client", add_help=False)
|
||||
parser.add_argument("-h", "--host", default="localhost", help="Admin service host")
|
||||
parser.add_argument("-p", "--port", type=int, default=9381, help="Admin service port")
|
||||
parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
|
||||
parser.add_argument("command", nargs="?", help="Single command")
|
||||
try:
|
||||
parsed_args, remaining_args = parser.parse_known_args(args)
|
||||
if remaining_args:
|
||||
command = remaining_args[0]
|
||||
return {"host": parsed_args.host, "port": parsed_args.port, "password": parsed_args.password, "command": command}
|
||||
else:
|
||||
return {
|
||||
"host": parsed_args.host,
|
||||
"port": parsed_args.port,
|
||||
}
|
||||
except SystemExit:
|
||||
return {"error": "Invalid connection arguments"}
|
||||
|
||||
def execute_command(self, parsed_command: Dict[str, Any]):
|
||||
command_dict: dict
|
||||
if isinstance(parsed_command, Tree):
|
||||
command_dict = parsed_command.children[0]
|
||||
else:
|
||||
if parsed_command["type"] == "error":
|
||||
print(f"Error: {parsed_command['message']}")
|
||||
return
|
||||
else:
|
||||
command_dict = parsed_command
|
||||
|
||||
# print(f"Parsed command: {command_dict}")
|
||||
|
||||
command_type = command_dict["type"]
|
||||
|
||||
match command_type:
|
||||
case "list_services":
|
||||
self._handle_list_services(command_dict)
|
||||
case "show_service":
|
||||
self._handle_show_service(command_dict)
|
||||
case "restart_service":
|
||||
self._handle_restart_service(command_dict)
|
||||
case "shutdown_service":
|
||||
self._handle_shutdown_service(command_dict)
|
||||
case "startup_service":
|
||||
self._handle_startup_service(command_dict)
|
||||
case "list_users":
|
||||
self._handle_list_users(command_dict)
|
||||
case "show_user":
|
||||
self._handle_show_user(command_dict)
|
||||
case "drop_user":
|
||||
self._handle_drop_user(command_dict)
|
||||
case "alter_user":
|
||||
self._handle_alter_user(command_dict)
|
||||
case "create_user":
|
||||
self._handle_create_user(command_dict)
|
||||
case "activate_user":
|
||||
self._handle_activate_user(command_dict)
|
||||
case "list_datasets":
|
||||
self._handle_list_datasets(command_dict)
|
||||
case "list_agents":
|
||||
self._handle_list_agents(command_dict)
|
||||
case "create_role":
|
||||
self._create_role(command_dict)
|
||||
case "drop_role":
|
||||
self._drop_role(command_dict)
|
||||
case "alter_role":
|
||||
self._alter_role(command_dict)
|
||||
case "list_roles":
|
||||
self._list_roles(command_dict)
|
||||
case "show_role":
|
||||
self._show_role(command_dict)
|
||||
case "grant_permission":
|
||||
self._grant_permission(command_dict)
|
||||
case "revoke_permission":
|
||||
self._revoke_permission(command_dict)
|
||||
case "alter_user_role":
|
||||
self._alter_user_role(command_dict)
|
||||
case "show_user_permission":
|
||||
self._show_user_permission(command_dict)
|
||||
case "show_version":
|
||||
self._show_version(command_dict)
|
||||
case "meta":
|
||||
self._handle_meta_command(command_dict)
|
||||
case _:
|
||||
print(f"Command '{command_type}' would be executed with API")
|
||||
|
||||
def _handle_list_services(self, command):
|
||||
print("Listing all services")
|
||||
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/services"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to get all services, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_show_service(self, command):
|
||||
service_id: int = command["number"]
|
||||
print(f"Showing service: {service_id}")
|
||||
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/services/{service_id}"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
res_data = res_json["data"]
|
||||
if "status" in res_data and res_data["status"] == "alive":
|
||||
print(f"Service {res_data['service_name']} is alive, ")
|
||||
if isinstance(res_data["message"], str):
|
||||
print(res_data["message"])
|
||||
else:
|
||||
data = self._format_service_detail_table(res_data["message"])
|
||||
self._print_table_simple(data)
|
||||
else:
|
||||
print(f"Service {res_data['service_name']} is down, {res_data['message']}")
|
||||
else:
|
||||
print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_restart_service(self, command):
|
||||
service_id: int = command["number"]
|
||||
print(f"Restart service {service_id}")
|
||||
|
||||
def _handle_shutdown_service(self, command):
|
||||
service_id: int = command["number"]
|
||||
print(f"Shutdown service {service_id}")
|
||||
|
||||
def _handle_startup_service(self, command):
|
||||
service_id: int = command["number"]
|
||||
print(f"Startup service {service_id}")
|
||||
|
||||
def _handle_list_users(self, command):
|
||||
print("Listing all users")
|
||||
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_show_user(self, command):
|
||||
username_tree: Tree = command["user_name"]
|
||||
user_name: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Showing user: {user_name}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
table_data = res_json["data"]
|
||||
table_data.pop("avatar")
|
||||
self._print_table_simple(table_data)
|
||||
else:
|
||||
print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_drop_user(self, command):
|
||||
username_tree: Tree = command["user_name"]
|
||||
user_name: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Drop user: {user_name}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}"
|
||||
response = self.session.delete(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_alter_user(self, command):
|
||||
user_name_tree: Tree = command["user_name"]
|
||||
user_name: str = user_name_tree.children[0].strip("'\"")
|
||||
password_tree: Tree = command["password"]
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
print(f"Alter user: {user_name}, password: ******")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password"
|
||||
response = self.session.put(url, json={"new_password": encrypt(password)})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_create_user(self, command):
|
||||
user_name_tree: Tree = command["user_name"]
|
||||
user_name: str = user_name_tree.children[0].strip("'\"")
|
||||
password_tree: Tree = command["password"]
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
role: str = command["role"]
|
||||
print(f"Create user: {user_name}, password: ******, role: {role}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users"
|
||||
response = self.session.post(url, json={"user_name": user_name, "password": encrypt(password), "role": role})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to create user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_activate_user(self, command):
|
||||
user_name_tree: Tree = command["user_name"]
|
||||
user_name: str = user_name_tree.children[0].strip("'\"")
|
||||
activate_tree: Tree = command["activate_status"]
|
||||
activate_status: str = activate_tree.children[0].strip("'\"")
|
||||
if activate_status.lower() in ["on", "off"]:
|
||||
print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/activate"
|
||||
response = self.session.put(url, json={"activate_status": activate_status})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
|
||||
else:
|
||||
print(f"Unknown activate status: {activate_status}.")
|
||||
|
||||
def _handle_list_datasets(self, command):
|
||||
username_tree: Tree = command["user_name"]
|
||||
user_name: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all datasets of user: {user_name}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/datasets"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
table_data = res_json["data"]
|
||||
for t in table_data:
|
||||
t.pop("avatar")
|
||||
self._print_table_simple(table_data)
|
||||
else:
|
||||
print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_list_agents(self, command):
|
||||
username_tree: Tree = command["user_name"]
|
||||
user_name: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Listing all agents of user: {user_name}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/agents"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
table_data = res_json["data"]
|
||||
for t in table_data:
|
||||
t.pop("avatar")
|
||||
self._print_table_simple(table_data)
|
||||
else:
|
||||
print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _create_role(self, command):
|
||||
role_name_tree: Tree = command["role_name"]
|
||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
||||
desc_str: str = ""
|
||||
if "description" in command:
|
||||
desc_tree: Tree = command["description"]
|
||||
desc_str = desc_tree.children[0].strip("'\"")
|
||||
|
||||
print(f"create role name: {role_name}, description: {desc_str}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/roles"
|
||||
response = self.session.post(url, json={"role_name": role_name, "description": desc_str})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to create role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _drop_role(self, command):
|
||||
role_name_tree: Tree = command["role_name"]
|
||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
||||
print(f"drop role name: {role_name}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}"
|
||||
response = self.session.delete(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to drop role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _alter_role(self, command):
|
||||
role_name_tree: Tree = command["role_name"]
|
||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
||||
desc_tree: Tree = command["description"]
|
||||
desc_str: str = desc_tree.children[0].strip("'\"")
|
||||
|
||||
print(f"alter role name: {role_name}, description: {desc_str}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}"
|
||||
response = self.session.put(url, json={"description": desc_str})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _list_roles(self, command):
|
||||
print("Listing all roles")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/roles"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _show_role(self, command):
|
||||
role_name_tree: Tree = command["role_name"]
|
||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
||||
print(f"show role: {role_name}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}/permission"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _grant_permission(self, command):
|
||||
role_name_tree: Tree = command["role_name"]
|
||||
role_name_str: str = role_name_tree.children[0].strip("'\"")
|
||||
resource_tree: Tree = command["resource"]
|
||||
resource_str: str = resource_tree.children[0].strip("'\"")
|
||||
action_tree_list: list = command["actions"]
|
||||
actions: list = []
|
||||
for action_tree in action_tree_list:
|
||||
action_str: str = action_tree.children[0].strip("'\"")
|
||||
actions.append(action_str)
|
||||
print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission"
|
||||
response = self.session.post(url, json={"actions": actions, "resource": resource_str})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _revoke_permission(self, command):
|
||||
role_name_tree: Tree = command["role_name"]
|
||||
role_name_str: str = role_name_tree.children[0].strip("'\"")
|
||||
resource_tree: Tree = command["resource"]
|
||||
resource_str: str = resource_tree.children[0].strip("'\"")
|
||||
action_tree_list: list = command["actions"]
|
||||
actions: list = []
|
||||
for action_tree in action_tree_list:
|
||||
action_str: str = action_tree.children[0].strip("'\"")
|
||||
actions.append(action_str)
|
||||
print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission"
|
||||
response = self.session.delete(url, json={"actions": actions, "resource": resource_str})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _alter_user_role(self, command):
|
||||
role_name_tree: Tree = command["role_name"]
|
||||
role_name_str: str = role_name_tree.children[0].strip("'\"")
|
||||
user_name_tree: Tree = command["user_name"]
|
||||
user_name_str: str = user_name_tree.children[0].strip("'\"")
|
||||
print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/role"
|
||||
response = self.session.put(url, json={"role_name": role_name_str})
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _show_user_permission(self, command):
|
||||
user_name_tree: Tree = command["user_name"]
|
||||
user_name_str: str = user_name_tree.children[0].strip("'\"")
|
||||
print(f"show_user_permission user_name: {user_name_str}")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/permission"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _show_version(self, command):
|
||||
print("show_version")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/version"
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json["data"])
|
||||
else:
|
||||
print(f"Fail to show version, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_meta_command(self, command):
|
||||
meta_command = command["command"]
|
||||
args = command.get("args", [])
|
||||
|
||||
if meta_command in ["?", "h", "help"]:
|
||||
self.show_help()
|
||||
elif meta_command in ["q", "quit", "exit"]:
|
||||
print("Goodbye!")
|
||||
else:
|
||||
print(f"Meta command '{meta_command}' with args {args}")
|
||||
|
||||
def show_help(self):
|
||||
"""Help info"""
|
||||
help_text = """
|
||||
Commands:
|
||||
LIST SERVICES
|
||||
SHOW SERVICE <service>
|
||||
STARTUP SERVICE <service>
|
||||
SHUTDOWN SERVICE <service>
|
||||
RESTART SERVICE <service>
|
||||
LIST USERS
|
||||
SHOW USER <user>
|
||||
DROP USER <user>
|
||||
CREATE USER <user> <password>
|
||||
ALTER USER PASSWORD <user> <new_password>
|
||||
ALTER USER ACTIVE <user> <on/off>
|
||||
LIST DATASETS OF <user>
|
||||
LIST AGENTS OF <user>
|
||||
|
||||
Meta Commands:
|
||||
\\?, \\h, \\help Show this help
|
||||
\\q, \\quit, \\exit Quit the CLI
|
||||
"""
|
||||
print(help_text)
|
||||
|
||||
|
||||
def main():
|
||||
import sys
|
||||
|
||||
cli = AdminCLI()
|
||||
|
||||
args = cli.parse_connection_args(sys.argv)
|
||||
if "error" in args:
|
||||
print("Error: Invalid connection arguments")
|
||||
return
|
||||
|
||||
if "command" in args:
|
||||
if "password" not in args:
|
||||
print("Error: password is missing")
|
||||
return
|
||||
if cli.verify_admin(args, single_command=True):
|
||||
command: str = args["command"]
|
||||
# print(f"Run single command: {command}")
|
||||
cli.run_single_command(command)
|
||||
else:
|
||||
if cli.verify_admin(args, single_command=False):
|
||||
print(r"""
|
||||
____ ___ ______________ ___ __ _
|
||||
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
|
||||
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
|
||||
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
|
||||
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
|
||||
""")
|
||||
cli.cmdloop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
166
admin/client/http_client.py
Normal file
166
admin/client/http_client.py
Normal file
@ -0,0 +1,166 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
# from requests.sessions import HTTPAdapter
|
||||
|
||||
|
||||
class HttpClient:
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9381,
|
||||
api_version: str = "v1",
|
||||
api_key: Optional[str] = None,
|
||||
connect_timeout: float = 5.0,
|
||||
read_timeout: float = 60.0,
|
||||
verify_ssl: bool = False,
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.login_token: str | None = None
|
||||
self.connect_timeout = connect_timeout
|
||||
self.read_timeout = read_timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
|
||||
def api_base(self) -> str:
|
||||
return f"{self.host}:{self.port}/api/{self.api_version}"
|
||||
|
||||
def non_api_base(self) -> str:
|
||||
return f"{self.host}:{self.port}/{self.api_version}"
|
||||
|
||||
def build_url(self, path: str, use_api_base: bool = True) -> str:
|
||||
base = self.api_base() if use_api_base else self.non_api_base()
|
||||
if self.verify_ssl:
|
||||
return f"https://{base}/{path.lstrip('/')}"
|
||||
else:
|
||||
return f"http://{base}/{path.lstrip('/')}"
|
||||
|
||||
def _headers(self, auth_kind: Optional[str], extra: Optional[Dict[str, str]]) -> Dict[str, str]:
|
||||
headers = {}
|
||||
if auth_kind == "api" and self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
elif auth_kind == "web" and self.login_token:
|
||||
headers["Authorization"] = self.login_token
|
||||
elif auth_kind == "admin" and self.login_token:
|
||||
headers["Authorization"] = self.login_token
|
||||
else:
|
||||
pass
|
||||
if extra:
|
||||
headers.update(extra)
|
||||
return headers
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
use_api_base: bool = True,
|
||||
auth_kind: Optional[str] = "api",
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json_body: Optional[Dict[str, Any]] = None,
|
||||
data: Any = None,
|
||||
files: Any = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
iterations: int = 1,
|
||||
) -> requests.Response | dict:
|
||||
url = self.build_url(path, use_api_base=use_api_base)
|
||||
merged_headers = self._headers(auth_kind, headers)
|
||||
# timeout: Tuple[float, float] = (self.connect_timeout, self.read_timeout)
|
||||
session = requests.Session()
|
||||
# adapter = HTTPAdapter(pool_connections=100, pool_maxsize=100)
|
||||
# session.mount("http://", adapter)
|
||||
if iterations > 1:
|
||||
response_list = []
|
||||
total_duration = 0.0
|
||||
for _ in range(iterations):
|
||||
start_time = time.perf_counter()
|
||||
response = session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
|
||||
# response = requests.request(
|
||||
# method=method,
|
||||
# url=url,
|
||||
# headers=merged_headers,
|
||||
# json=json_body,
|
||||
# data=data,
|
||||
# files=files,
|
||||
# params=params,
|
||||
# timeout=timeout,
|
||||
# stream=stream,
|
||||
# verify=self.verify_ssl,
|
||||
# )
|
||||
end_time = time.perf_counter()
|
||||
total_duration += end_time - start_time
|
||||
response_list.append(response)
|
||||
return {"duration": total_duration, "response_list": response_list}
|
||||
else:
|
||||
return session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
|
||||
# return requests.request(
|
||||
# method=method,
|
||||
# url=url,
|
||||
# headers=merged_headers,
|
||||
# json=json_body,
|
||||
# data=data,
|
||||
# files=files,
|
||||
# params=params,
|
||||
# timeout=timeout,
|
||||
# stream=stream,
|
||||
# verify=self.verify_ssl,
|
||||
# )
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
use_api_base: bool = True,
|
||||
auth_kind: Optional[str] = "api",
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json_body: Optional[Dict[str, Any]] = None,
|
||||
data: Any = None,
|
||||
files: Any = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
response = self.request(
|
||||
method,
|
||||
path,
|
||||
use_api_base=use_api_base,
|
||||
auth_kind=auth_kind,
|
||||
headers=headers,
|
||||
json_body=json_body,
|
||||
data=data,
|
||||
files=files,
|
||||
params=params,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
return response.json()
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Non-JSON response from {path}: {exc}") from exc
|
||||
|
||||
@staticmethod
|
||||
def parse_json_bytes(raw: bytes) -> Dict[str, Any]:
|
||||
try:
|
||||
return json.loads(raw.decode("utf-8"))
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Invalid JSON payload: {exc}") from exc
|
||||
623
admin/client/parser.py
Normal file
623
admin/client/parser.py
Normal file
@ -0,0 +1,623 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from lark import Transformer
|
||||
|
||||
GRAMMAR = r"""
|
||||
start: command
|
||||
|
||||
command: sql_command | meta_command
|
||||
|
||||
sql_command: login_user
|
||||
| ping_server
|
||||
| list_services
|
||||
| show_service
|
||||
| startup_service
|
||||
| shutdown_service
|
||||
| restart_service
|
||||
| register_user
|
||||
| list_users
|
||||
| show_user
|
||||
| drop_user
|
||||
| alter_user
|
||||
| create_user
|
||||
| activate_user
|
||||
| list_datasets
|
||||
| list_agents
|
||||
| create_role
|
||||
| drop_role
|
||||
| alter_role
|
||||
| list_roles
|
||||
| show_role
|
||||
| grant_permission
|
||||
| revoke_permission
|
||||
| alter_user_role
|
||||
| show_user_permission
|
||||
| show_version
|
||||
| grant_admin
|
||||
| revoke_admin
|
||||
| set_variable
|
||||
| show_variable
|
||||
| list_variables
|
||||
| list_configs
|
||||
| list_environments
|
||||
| generate_key
|
||||
| list_keys
|
||||
| drop_key
|
||||
| show_current_user
|
||||
| set_default_llm
|
||||
| set_default_vlm
|
||||
| set_default_embedding
|
||||
| set_default_reranker
|
||||
| set_default_asr
|
||||
| set_default_tts
|
||||
| reset_default_llm
|
||||
| reset_default_vlm
|
||||
| reset_default_embedding
|
||||
| reset_default_reranker
|
||||
| reset_default_asr
|
||||
| reset_default_tts
|
||||
| create_model_provider
|
||||
| drop_model_provider
|
||||
| create_user_dataset_with_parser
|
||||
| create_user_dataset_with_pipeline
|
||||
| drop_user_dataset
|
||||
| list_user_datasets
|
||||
| list_user_dataset_files
|
||||
| list_user_agents
|
||||
| list_user_chats
|
||||
| create_user_chat
|
||||
| drop_user_chat
|
||||
| list_user_model_providers
|
||||
| list_user_default_models
|
||||
| parse_dataset_docs
|
||||
| parse_dataset_sync
|
||||
| parse_dataset_async
|
||||
| import_docs_into_dataset
|
||||
| search_on_datasets
|
||||
| benchmark
|
||||
|
||||
// meta command definition
|
||||
meta_command: "\\" meta_command_name [meta_args]
|
||||
|
||||
meta_command_name: /[a-zA-Z?]+/
|
||||
meta_args: (meta_arg)+
|
||||
|
||||
meta_arg: /[^\\s"']+/ | quoted_string
|
||||
|
||||
// command definition
|
||||
|
||||
LOGIN: "LOGIN"i
|
||||
REGISTER: "REGISTER"i
|
||||
LIST: "LIST"i
|
||||
SERVICES: "SERVICES"i
|
||||
SHOW: "SHOW"i
|
||||
CREATE: "CREATE"i
|
||||
SERVICE: "SERVICE"i
|
||||
SHUTDOWN: "SHUTDOWN"i
|
||||
STARTUP: "STARTUP"i
|
||||
RESTART: "RESTART"i
|
||||
USERS: "USERS"i
|
||||
DROP: "DROP"i
|
||||
USER: "USER"i
|
||||
ALTER: "ALTER"i
|
||||
ACTIVE: "ACTIVE"i
|
||||
ADMIN: "ADMIN"i
|
||||
PASSWORD: "PASSWORD"i
|
||||
DATASET: "DATASET"i
|
||||
DATASETS: "DATASETS"i
|
||||
OF: "OF"i
|
||||
AGENTS: "AGENTS"i
|
||||
ROLE: "ROLE"i
|
||||
ROLES: "ROLES"i
|
||||
DESCRIPTION: "DESCRIPTION"i
|
||||
GRANT: "GRANT"i
|
||||
REVOKE: "REVOKE"i
|
||||
ALL: "ALL"i
|
||||
PERMISSION: "PERMISSION"i
|
||||
TO: "TO"i
|
||||
FROM: "FROM"i
|
||||
FOR: "FOR"i
|
||||
RESOURCES: "RESOURCES"i
|
||||
ON: "ON"i
|
||||
SET: "SET"i
|
||||
RESET: "RESET"i
|
||||
VERSION: "VERSION"i
|
||||
VAR: "VAR"i
|
||||
VARS: "VARS"i
|
||||
CONFIGS: "CONFIGS"i
|
||||
ENVS: "ENVS"i
|
||||
KEY: "KEY"i
|
||||
KEYS: "KEYS"i
|
||||
GENERATE: "GENERATE"i
|
||||
MODEL: "MODEL"i
|
||||
MODELS: "MODELS"i
|
||||
PROVIDER: "PROVIDER"i
|
||||
PROVIDERS: "PROVIDERS"i
|
||||
DEFAULT: "DEFAULT"i
|
||||
CHATS: "CHATS"i
|
||||
CHAT: "CHAT"i
|
||||
FILES: "FILES"i
|
||||
AS: "AS"i
|
||||
PARSE: "PARSE"i
|
||||
IMPORT: "IMPORT"i
|
||||
INTO: "INTO"i
|
||||
WITH: "WITH"i
|
||||
PARSER: "PARSER"i
|
||||
PIPELINE: "PIPELINE"i
|
||||
SEARCH: "SEARCH"i
|
||||
CURRENT: "CURRENT"i
|
||||
LLM: "LLM"i
|
||||
VLM: "VLM"i
|
||||
EMBEDDING: "EMBEDDING"i
|
||||
RERANKER: "RERANKER"i
|
||||
ASR: "ASR"i
|
||||
TTS: "TTS"i
|
||||
ASYNC: "ASYNC"i
|
||||
SYNC: "SYNC"i
|
||||
BENCHMARK: "BENCHMARK"i
|
||||
PING: "PING"i
|
||||
|
||||
login_user: LOGIN USER quoted_string ";"
|
||||
list_services: LIST SERVICES ";"
|
||||
show_service: SHOW SERVICE NUMBER ";"
|
||||
startup_service: STARTUP SERVICE NUMBER ";"
|
||||
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
|
||||
restart_service: RESTART SERVICE NUMBER ";"
|
||||
|
||||
register_user: REGISTER USER quoted_string AS quoted_string PASSWORD quoted_string ";"
|
||||
list_users: LIST USERS ";"
|
||||
drop_user: DROP USER quoted_string ";"
|
||||
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||
show_user: SHOW USER quoted_string ";"
|
||||
create_user: CREATE USER quoted_string quoted_string ";"
|
||||
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||
|
||||
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||
list_agents: LIST AGENTS OF quoted_string ";"
|
||||
|
||||
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
|
||||
drop_role: DROP ROLE identifier ";"
|
||||
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
|
||||
list_roles: LIST ROLES ";"
|
||||
show_role: SHOW ROLE identifier ";"
|
||||
|
||||
grant_permission: GRANT identifier_list ON identifier TO ROLE identifier ";"
|
||||
revoke_permission: REVOKE identifier_list ON identifier FROM ROLE identifier ";"
|
||||
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
|
||||
show_user_permission: SHOW USER PERMISSION quoted_string ";"
|
||||
|
||||
show_version: SHOW VERSION ";"
|
||||
|
||||
grant_admin: GRANT ADMIN quoted_string ";"
|
||||
revoke_admin: REVOKE ADMIN quoted_string ";"
|
||||
|
||||
generate_key: GENERATE KEY FOR USER quoted_string ";"
|
||||
list_keys: LIST KEYS OF quoted_string ";"
|
||||
drop_key: DROP KEY quoted_string OF quoted_string ";"
|
||||
|
||||
set_variable: SET VAR identifier identifier ";"
|
||||
show_variable: SHOW VAR identifier ";"
|
||||
list_variables: LIST VARS ";"
|
||||
list_configs: LIST CONFIGS ";"
|
||||
list_environments: LIST ENVS ";"
|
||||
|
||||
benchmark: BENCHMARK NUMBER NUMBER user_statement
|
||||
|
||||
user_statement: ping_server
|
||||
| show_current_user
|
||||
| create_model_provider
|
||||
| drop_model_provider
|
||||
| set_default_llm
|
||||
| set_default_vlm
|
||||
| set_default_embedding
|
||||
| set_default_reranker
|
||||
| set_default_asr
|
||||
| set_default_tts
|
||||
| reset_default_llm
|
||||
| reset_default_vlm
|
||||
| reset_default_embedding
|
||||
| reset_default_reranker
|
||||
| reset_default_asr
|
||||
| reset_default_tts
|
||||
| create_user_dataset_with_parser
|
||||
| create_user_dataset_with_pipeline
|
||||
| drop_user_dataset
|
||||
| list_user_datasets
|
||||
| list_user_dataset_files
|
||||
| list_user_agents
|
||||
| list_user_chats
|
||||
| create_user_chat
|
||||
| drop_user_chat
|
||||
| list_user_model_providers
|
||||
| list_user_default_models
|
||||
| import_docs_into_dataset
|
||||
| search_on_datasets
|
||||
|
||||
ping_server: PING ";"
|
||||
show_current_user: SHOW CURRENT USER ";"
|
||||
create_model_provider: CREATE MODEL PROVIDER quoted_string quoted_string ";"
|
||||
drop_model_provider: DROP MODEL PROVIDER quoted_string ";"
|
||||
set_default_llm: SET DEFAULT LLM quoted_string ";"
|
||||
set_default_vlm: SET DEFAULT VLM quoted_string ";"
|
||||
set_default_embedding: SET DEFAULT EMBEDDING quoted_string ";"
|
||||
set_default_reranker: SET DEFAULT RERANKER quoted_string ";"
|
||||
set_default_asr: SET DEFAULT ASR quoted_string ";"
|
||||
set_default_tts: SET DEFAULT TTS quoted_string ";"
|
||||
|
||||
reset_default_llm: RESET DEFAULT LLM ";"
|
||||
reset_default_vlm: RESET DEFAULT VLM ";"
|
||||
reset_default_embedding: RESET DEFAULT EMBEDDING ";"
|
||||
reset_default_reranker: RESET DEFAULT RERANKER ";"
|
||||
reset_default_asr: RESET DEFAULT ASR ";"
|
||||
reset_default_tts: RESET DEFAULT TTS ";"
|
||||
|
||||
list_user_datasets: LIST DATASETS ";"
|
||||
create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";"
|
||||
create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";"
|
||||
drop_user_dataset: DROP DATASET quoted_string ";"
|
||||
list_user_dataset_files: LIST FILES OF DATASET quoted_string ";"
|
||||
list_user_agents: LIST AGENTS ";"
|
||||
list_user_chats: LIST CHATS ";"
|
||||
create_user_chat: CREATE CHAT quoted_string ";"
|
||||
drop_user_chat: DROP CHAT quoted_string ";"
|
||||
list_user_model_providers: LIST MODEL PROVIDERS ";"
|
||||
list_user_default_models: LIST DEFAULT MODELS ";"
|
||||
import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";"
|
||||
search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";"
|
||||
|
||||
parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
|
||||
parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
|
||||
parse_dataset_async: PARSE DATASET quoted_string ASYNC ";"
|
||||
|
||||
identifier_list: identifier ("," identifier)*
|
||||
|
||||
identifier: WORD
|
||||
quoted_string: QUOTED_STRING
|
||||
status: WORD
|
||||
|
||||
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||
NUMBER: /[0-9]+/
|
||||
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
|
||||
class RAGFlowCLITransformer(Transformer):
|
||||
def start(self, items):
|
||||
return items[0]
|
||||
|
||||
def command(self, items):
|
||||
return items[0]
|
||||
|
||||
def login_user(self, items):
|
||||
email = items[2].children[0].strip("'\"")
|
||||
return {"type": "login_user", "email": email}
|
||||
|
||||
def ping_server(self, items):
|
||||
return {"type": "ping_server"}
|
||||
|
||||
def list_services(self, items):
|
||||
result = {"type": "list_services"}
|
||||
return result
|
||||
|
||||
def show_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "show_service", "number": service_id}
|
||||
|
||||
def startup_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "startup_service", "number": service_id}
|
||||
|
||||
def shutdown_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "shutdown_service", "number": service_id}
|
||||
|
||||
def restart_service(self, items):
|
||||
service_id = int(items[2])
|
||||
return {"type": "restart_service", "number": service_id}
|
||||
|
||||
def register_user(self, items):
|
||||
user_name: str = items[2].children[0].strip("'\"")
|
||||
nickname: str = items[4].children[0].strip("'\"")
|
||||
password: str = items[6].children[0].strip("'\"")
|
||||
return {"type": "register_user", "user_name": user_name, "nickname": nickname, "password": password}
|
||||
|
||||
def list_users(self, items):
|
||||
return {"type": "list_users"}
|
||||
|
||||
def show_user(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "show_user", "user_name": user_name}
|
||||
|
||||
def drop_user(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "drop_user", "user_name": user_name}
|
||||
|
||||
def alter_user(self, items):
|
||||
user_name = items[3]
|
||||
new_password = items[4]
|
||||
return {"type": "alter_user", "user_name": user_name, "password": new_password}
|
||||
|
||||
def create_user(self, items):
|
||||
user_name = items[2]
|
||||
password = items[3]
|
||||
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
|
||||
|
||||
def activate_user(self, items):
|
||||
user_name = items[3]
|
||||
activate_status = items[4]
|
||||
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
|
||||
|
||||
def list_datasets(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_datasets", "user_name": user_name}
|
||||
|
||||
def list_agents(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_agents", "user_name": user_name}
|
||||
|
||||
def create_role(self, items):
|
||||
role_name = items[2]
|
||||
if len(items) > 4:
|
||||
description = items[4]
|
||||
return {"type": "create_role", "role_name": role_name, "description": description}
|
||||
else:
|
||||
return {"type": "create_role", "role_name": role_name}
|
||||
|
||||
def drop_role(self, items):
|
||||
role_name = items[2]
|
||||
return {"type": "drop_role", "role_name": role_name}
|
||||
|
||||
def alter_role(self, items):
|
||||
role_name = items[2]
|
||||
description = items[5]
|
||||
return {"type": "alter_role", "role_name": role_name, "description": description}
|
||||
|
||||
def list_roles(self, items):
|
||||
return {"type": "list_roles"}
|
||||
|
||||
def show_role(self, items):
|
||||
role_name = items[2]
|
||||
return {"type": "show_role", "role_name": role_name}
|
||||
|
||||
def grant_permission(self, items):
|
||||
action_list = items[1]
|
||||
resource = items[3]
|
||||
role_name = items[6]
|
||||
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
|
||||
|
||||
def revoke_permission(self, items):
|
||||
action_list = items[1]
|
||||
resource = items[3]
|
||||
role_name = items[6]
|
||||
return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
|
||||
|
||||
def alter_user_role(self, items):
|
||||
user_name = items[2]
|
||||
role_name = items[5]
|
||||
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
|
||||
|
||||
def show_user_permission(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "show_user_permission", "user_name": user_name}
|
||||
|
||||
def show_version(self, items):
|
||||
return {"type": "show_version"}
|
||||
|
||||
def grant_admin(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "grant_admin", "user_name": user_name}
|
||||
|
||||
def revoke_admin(self, items):
|
||||
user_name = items[2]
|
||||
return {"type": "revoke_admin", "user_name": user_name}
|
||||
|
||||
def generate_key(self, items):
|
||||
user_name = items[4]
|
||||
return {"type": "generate_key", "user_name": user_name}
|
||||
|
||||
def list_keys(self, items):
|
||||
user_name = items[3]
|
||||
return {"type": "list_keys", "user_name": user_name}
|
||||
|
||||
def drop_key(self, items):
|
||||
key = items[2]
|
||||
user_name = items[4]
|
||||
return {"type": "drop_key", "key": key, "user_name": user_name}
|
||||
|
||||
def set_variable(self, items):
|
||||
var_name = items[2]
|
||||
var_value = items[3]
|
||||
return {"type": "set_variable", "var_name": var_name, "var_value": var_value}
|
||||
|
||||
def show_variable(self, items):
|
||||
var_name = items[2]
|
||||
return {"type": "show_variable", "var_name": var_name}
|
||||
|
||||
def list_variables(self, items):
|
||||
return {"type": "list_variables"}
|
||||
|
||||
def list_configs(self, items):
|
||||
return {"type": "list_configs"}
|
||||
|
||||
def list_environments(self, items):
|
||||
return {"type": "list_environments"}
|
||||
|
||||
def create_model_provider(self, items):
|
||||
provider_name = items[3].children[0].strip("'\"")
|
||||
provider_key = items[4].children[0].strip("'\"")
|
||||
return {"type": "create_model_provider", "provider_name": provider_name, "provider_key": provider_key}
|
||||
|
||||
def drop_model_provider(self, items):
|
||||
provider_name = items[3].children[0].strip("'\"")
|
||||
return {"type": "drop_model_provider", "provider_name": provider_name}
|
||||
|
||||
def show_current_user(self, items):
|
||||
return {"type": "show_current_user"}
|
||||
|
||||
def set_default_llm(self, items):
|
||||
llm_id = items[3].children[0].strip("'\"")
|
||||
return {"type": "set_default_model", "model_type": "llm_id", "model_id": llm_id}
|
||||
|
||||
def set_default_vlm(self, items):
|
||||
vlm_id = items[3].children[0].strip("'\"")
|
||||
return {"type": "set_default_model", "model_type": "img2txt_id", "model_id": vlm_id}
|
||||
|
||||
def set_default_embedding(self, items):
|
||||
embedding_id = items[3].children[0].strip("'\"")
|
||||
return {"type": "set_default_model", "model_type": "embd_id", "model_id": embedding_id}
|
||||
|
||||
def set_default_reranker(self, items):
|
||||
reranker_id = items[3].children[0].strip("'\"")
|
||||
return {"type": "set_default_model", "model_type": "reranker_id", "model_id": reranker_id}
|
||||
|
||||
def set_default_asr(self, items):
|
||||
asr_id = items[3].children[0].strip("'\"")
|
||||
return {"type": "set_default_model", "model_type": "asr_id", "model_id": asr_id}
|
||||
|
||||
def set_default_tts(self, items):
|
||||
tts_id = items[3].children[0].strip("'\"")
|
||||
return {"type": "set_default_model", "model_type": "tts_id", "model_id": tts_id}
|
||||
|
||||
def reset_default_llm(self, items):
|
||||
return {"type": "reset_default_model", "model_type": "llm_id"}
|
||||
|
||||
def reset_default_vlm(self, items):
|
||||
return {"type": "reset_default_model", "model_type": "img2txt_id"}
|
||||
|
||||
def reset_default_embedding(self, items):
|
||||
return {"type": "reset_default_model", "model_type": "embd_id"}
|
||||
|
||||
def reset_default_reranker(self, items):
|
||||
return {"type": "reset_default_model", "model_type": "reranker_id"}
|
||||
|
||||
def reset_default_asr(self, items):
|
||||
return {"type": "reset_default_model", "model_type": "asr_id"}
|
||||
|
||||
def reset_default_tts(self, items):
|
||||
return {"type": "reset_default_model", "model_type": "tts_id"}
|
||||
|
||||
def list_user_datasets(self, items):
|
||||
return {"type": "list_user_datasets"}
|
||||
|
||||
def create_user_dataset_with_parser(self, items):
|
||||
dataset_name = items[2].children[0].strip("'\"")
|
||||
embedding = items[5].children[0].strip("'\"")
|
||||
parser_type = items[7].children[0].strip("'\"")
|
||||
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
|
||||
"parser_type": parser_type}
|
||||
|
||||
def create_user_dataset_with_pipeline(self, items):
|
||||
dataset_name = items[2].children[0].strip("'\"")
|
||||
embedding = items[5].children[0].strip("'\"")
|
||||
pipeline = items[7].children[0].strip("'\"")
|
||||
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
|
||||
"pipeline": pipeline}
|
||||
|
||||
def drop_user_dataset(self, items):
|
||||
dataset_name = items[2].children[0].strip("'\"")
|
||||
return {"type": "drop_user_dataset", "dataset_name": dataset_name}
|
||||
|
||||
def list_user_dataset_files(self, items):
|
||||
dataset_name = items[4].children[0].strip("'\"")
|
||||
return {"type": "list_user_dataset_files", "dataset_name": dataset_name}
|
||||
|
||||
def list_user_agents(self, items):
|
||||
return {"type": "list_user_agents"}
|
||||
|
||||
def list_user_chats(self, items):
|
||||
return {"type": "list_user_chats"}
|
||||
|
||||
def create_user_chat(self, items):
|
||||
chat_name = items[2].children[0].strip("'\"")
|
||||
return {"type": "create_user_chat", "chat_name": chat_name}
|
||||
|
||||
def drop_user_chat(self, items):
|
||||
chat_name = items[2].children[0].strip("'\"")
|
||||
return {"type": "drop_user_chat", "chat_name": chat_name}
|
||||
|
||||
def list_user_model_providers(self, items):
|
||||
return {"type": "list_user_model_providers"}
|
||||
|
||||
def list_user_default_models(self, items):
|
||||
return {"type": "list_user_default_models"}
|
||||
|
||||
def parse_dataset_docs(self, items):
|
||||
document_list_str = items[1].children[0].strip("'\"")
|
||||
document_names = document_list_str.split(",")
|
||||
if len(document_names) == 1:
|
||||
document_names = document_names[0]
|
||||
document_names = document_names.split(" ")
|
||||
dataset_name = items[4].children[0].strip("'\"")
|
||||
return {"type": "parse_dataset_docs", "dataset_name": dataset_name, "document_names": document_names}
|
||||
|
||||
def parse_dataset_sync(self, items):
|
||||
dataset_name = items[2].children[0].strip("'\"")
|
||||
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "sync"}
|
||||
|
||||
def parse_dataset_async(self, items):
|
||||
dataset_name = items[2].children[0].strip("'\"")
|
||||
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "async"}
|
||||
|
||||
def import_docs_into_dataset(self, items):
|
||||
document_list_str = items[1].children[0].strip("'\"")
|
||||
document_paths = document_list_str.split(",")
|
||||
if len(document_paths) == 1:
|
||||
document_paths = document_paths[0]
|
||||
document_paths = document_paths.split(" ")
|
||||
dataset_name = items[4].children[0].strip("'\"")
|
||||
return {"type": "import_docs_into_dataset", "dataset_name": dataset_name, "document_paths": document_paths}
|
||||
|
||||
def search_on_datasets(self, items):
|
||||
question = items[1].children[0].strip("'\"")
|
||||
datasets_str = items[4].children[0].strip("'\"")
|
||||
datasets = datasets_str.split(",")
|
||||
if len(datasets) == 1:
|
||||
datasets = datasets[0]
|
||||
datasets = datasets.split(" ")
|
||||
return {"type": "search_on_datasets", "datasets": datasets, "question": question}
|
||||
|
||||
def benchmark(self, items):
|
||||
concurrency: int = int(items[1])
|
||||
iterations: int = int(items[2])
|
||||
command = items[3].children[0]
|
||||
return {"type": "benchmark", "concurrency": concurrency, "iterations": iterations, "command": command}
|
||||
|
||||
def action_list(self, items):
|
||||
return items
|
||||
|
||||
def meta_command(self, items):
|
||||
command_name = str(items[0]).lower()
|
||||
args = items[1:] if len(items) > 1 else []
|
||||
|
||||
# handle quoted parameter
|
||||
parsed_args = []
|
||||
for arg in args:
|
||||
if hasattr(arg, "value"):
|
||||
parsed_args.append(arg.value)
|
||||
else:
|
||||
parsed_args.append(str(arg))
|
||||
|
||||
return {"type": "meta", "command": command_name, "args": parsed_args}
|
||||
|
||||
def meta_command_name(self, items):
|
||||
return items[0]
|
||||
|
||||
def meta_args(self, items):
|
||||
return items
|
||||
@ -20,5 +20,8 @@ test = [
|
||||
"requests-toolbelt>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["ragflow_cli", "parser"]
|
||||
|
||||
[project.scripts]
|
||||
ragflow-cli = "admin_client:main"
|
||||
ragflow-cli = "ragflow_cli:main"
|
||||
|
||||
322
admin/client/ragflow_cli.py
Normal file
322
admin/client/ragflow_cli.py
Normal file
@ -0,0 +1,322 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import base64
|
||||
import getpass
|
||||
from cmd import Cmd
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
import warnings
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from lark import Lark, Tree
|
||||
from parser import GRAMMAR, RAGFlowCLITransformer
|
||||
from http_client import HttpClient
|
||||
from ragflow_client import RAGFlowClient, run_command
|
||||
from user import login_user
|
||||
|
||||
warnings.filterwarnings("ignore", category=getpass.GetPassWarning)
|
||||
|
||||
def encrypt(input_string):
|
||||
pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
|
||||
pub_key = RSA.importKey(pub)
|
||||
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
|
||||
return base64.b64encode(cipher_text).decode("utf-8")
|
||||
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode("utf-8"))
|
||||
return base64_encoded.decode("utf-8")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class RAGFlowCLI(Cmd):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
|
||||
self.command_history = []
|
||||
self.account = "admin@ragflow.io"
|
||||
self.account_password: str = "admin"
|
||||
self.session = requests.Session()
|
||||
self.host: str = ""
|
||||
self.port: int = 0
|
||||
self.mode: str = "admin"
|
||||
self.ragflow_client = None
|
||||
|
||||
intro = r"""Type "\h" for help."""
|
||||
prompt = "ragflow> "
|
||||
|
||||
def onecmd(self, command: str) -> bool:
|
||||
try:
|
||||
result = self.parse_command(command)
|
||||
|
||||
if isinstance(result, dict):
|
||||
if "type" in result and result.get("type") == "empty":
|
||||
return False
|
||||
|
||||
self.execute_command(result)
|
||||
|
||||
if isinstance(result, Tree):
|
||||
return False
|
||||
|
||||
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
|
||||
return True
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nUse '\\q' to quit")
|
||||
except EOFError:
|
||||
print("\nGoodbye!")
|
||||
return True
|
||||
return False
|
||||
|
||||
def emptyline(self) -> bool:
|
||||
return False
|
||||
|
||||
def default(self, line: str) -> bool:
|
||||
return self.onecmd(line)
|
||||
|
||||
def parse_command(self, command_str: str) -> dict[str, str]:
|
||||
if not command_str.strip():
|
||||
return {"type": "empty"}
|
||||
|
||||
self.command_history.append(command_str)
|
||||
|
||||
try:
|
||||
result = self.parser.parse(command_str)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"type": "error", "message": f"Parse error: {str(e)}"}
|
||||
|
||||
def verify_auth(self, arguments: dict, single_command: bool, auth: bool):
|
||||
server_type = arguments.get("type", "admin")
|
||||
http_client = HttpClient(arguments["host"], arguments["port"])
|
||||
if not auth:
|
||||
self.ragflow_client = RAGFlowClient(http_client, server_type)
|
||||
return True
|
||||
|
||||
user_name = arguments["username"]
|
||||
attempt_count = 3
|
||||
if single_command:
|
||||
attempt_count = 1
|
||||
|
||||
try_count = 0
|
||||
while True:
|
||||
try_count += 1
|
||||
if try_count > attempt_count:
|
||||
return False
|
||||
|
||||
if single_command:
|
||||
user_password = arguments["password"]
|
||||
else:
|
||||
user_password = getpass.getpass(f"password for {user_name}: ").strip()
|
||||
|
||||
try:
|
||||
token = login_user(http_client, server_type, user_name, user_password)
|
||||
http_client.login_token = token
|
||||
self.ragflow_client = RAGFlowClient(http_client, server_type)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print("Can't access server for login (connection failed)")
|
||||
|
||||
def _format_service_detail_table(self, data):
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
if not all([isinstance(v, list) for v in data.values()]):
|
||||
# normal table
|
||||
return data
|
||||
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
|
||||
task_executor_list = []
|
||||
for k, v in data.items():
|
||||
# display latest status
|
||||
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
|
||||
task_executor_list.append(
|
||||
{
|
||||
"task_executor_name": k,
|
||||
**heartbeats[0],
|
||||
}
|
||||
if heartbeats
|
||||
else {"task_executor_name": k}
|
||||
)
|
||||
return task_executor_list
|
||||
|
||||
def _print_table_simple(self, data):
|
||||
if not data:
|
||||
print("No data to print")
|
||||
return
|
||||
if isinstance(data, dict):
|
||||
# handle single row data
|
||||
data = [data]
|
||||
|
||||
columns = list(set().union(*(d.keys() for d in data)))
|
||||
columns.sort()
|
||||
col_widths = {}
|
||||
|
||||
def get_string_width(text):
|
||||
half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
|
||||
width = 0
|
||||
for char in text:
|
||||
if char in half_width_chars:
|
||||
width += 1
|
||||
else:
|
||||
width += 2
|
||||
return width
|
||||
|
||||
for col in columns:
|
||||
max_width = get_string_width(str(col))
|
||||
for item in data:
|
||||
value_len = get_string_width(str(item.get(col, "")))
|
||||
if value_len > max_width:
|
||||
max_width = value_len
|
||||
col_widths[col] = max(2, max_width)
|
||||
|
||||
# Generate delimiter
|
||||
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
|
||||
|
||||
# Print header
|
||||
print(separator)
|
||||
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
|
||||
print(header)
|
||||
print(separator)
|
||||
|
||||
# Print data
|
||||
for item in data:
|
||||
row = "|"
|
||||
for col in columns:
|
||||
value = str(item.get(col, ""))
|
||||
if get_string_width(value) > col_widths[col]:
|
||||
value = value[: col_widths[col] - 3] + "..."
|
||||
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
|
||||
print(row)
|
||||
|
||||
print(separator)
|
||||
|
||||
def run_interactive(self, args):
|
||||
if self.verify_auth(args, single_command=False, auth=args["auth"]):
|
||||
print(r"""
|
||||
____ ___ ______________ ________ ____
|
||||
/ __ \/ | / ____/ ____/ /___ _ __ / ____/ / / _/
|
||||
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / / / / / /
|
||||
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / /___/ /____/ /
|
||||
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ \____/_____/___/
|
||||
""")
|
||||
self.cmdloop()
|
||||
|
||||
print("RAGFlow command line interface - Type '\\?' for help, '\\q' to quit")
|
||||
|
||||
def run_single_command(self, args):
|
||||
if self.verify_auth(args, single_command=True, auth=args["auth"]):
|
||||
command = args["command"]
|
||||
result = self.parse_command(command)
|
||||
self.execute_command(result)
|
||||
|
||||
|
||||
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
|
||||
parser = argparse.ArgumentParser(description="RAGFlow CLI Client", add_help=False)
|
||||
parser.add_argument("-h", "--host", default="127.0.0.1", help="Admin or RAGFlow service host")
|
||||
parser.add_argument("-p", "--port", type=int, default=9381, help="Admin or RAGFlow service port")
|
||||
parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
|
||||
parser.add_argument("-t", "--type", default="admin", type=str, help="CLI mode, admin or user")
|
||||
parser.add_argument("-u", "--username", default=None,
|
||||
help="Username (email). In admin mode defaults to admin@ragflow.io, in user mode required.")
|
||||
parser.add_argument("command", nargs="?", help="Single command")
|
||||
try:
|
||||
parsed_args, remaining_args = parser.parse_known_args(args)
|
||||
# Determine username based on mode
|
||||
username = parsed_args.username
|
||||
if parsed_args.type == "admin":
|
||||
if username is None:
|
||||
username = "admin@ragflow.io"
|
||||
|
||||
if remaining_args:
|
||||
if remaining_args[0] == "command":
|
||||
command_str = ' '.join(remaining_args[1:]) + ';'
|
||||
auth = True
|
||||
if remaining_args[1] == "register":
|
||||
auth = False
|
||||
else:
|
||||
if username is None:
|
||||
print("Error: username (-u) is required in user mode")
|
||||
return {"error": "Username required"}
|
||||
return {
|
||||
"host": parsed_args.host,
|
||||
"port": parsed_args.port,
|
||||
"password": parsed_args.password,
|
||||
"type": parsed_args.type,
|
||||
"username": username,
|
||||
"command": command_str,
|
||||
"auth": auth
|
||||
}
|
||||
else:
|
||||
return {"error": "Invalid command"}
|
||||
else:
|
||||
auth = True
|
||||
if username is None:
|
||||
auth = False
|
||||
return {
|
||||
"host": parsed_args.host,
|
||||
"port": parsed_args.port,
|
||||
"type": parsed_args.type,
|
||||
"username": username,
|
||||
"auth": auth
|
||||
}
|
||||
except SystemExit:
|
||||
return {"error": "Invalid connection arguments"}
|
||||
|
||||
def execute_command(self, parsed_command: Dict[str, Any]):
|
||||
command_dict: dict
|
||||
if isinstance(parsed_command, Tree):
|
||||
command_dict = parsed_command.children[0]
|
||||
else:
|
||||
if parsed_command["type"] == "error":
|
||||
print(f"Error: {parsed_command['message']}")
|
||||
return
|
||||
else:
|
||||
command_dict = parsed_command
|
||||
|
||||
# print(f"Parsed command: {command_dict}")
|
||||
run_command(self.ragflow_client, command_dict)
|
||||
|
||||
def main():
|
||||
|
||||
cli = RAGFlowCLI()
|
||||
|
||||
args = cli.parse_connection_args(sys.argv)
|
||||
if "error" in args:
|
||||
print("Error: Invalid connection arguments")
|
||||
return
|
||||
|
||||
if "command" in args:
|
||||
# single command mode
|
||||
# for user mode, api key or password is ok
|
||||
# for admin mode, only password
|
||||
if "password" not in args:
|
||||
print("Error: password is missing")
|
||||
return
|
||||
|
||||
cli.run_single_command(args)
|
||||
else:
|
||||
cli.run_interactive(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1492
admin/client/ragflow_client.py
Normal file
1492
admin/client/ragflow_client.py
Normal file
File diff suppressed because it is too large
Load Diff
65
admin/client/user.py
Normal file
65
admin/client/user.py
Normal file
@ -0,0 +1,65 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from http_client import HttpClient
|
||||
|
||||
|
||||
class AuthException(Exception):
|
||||
def __init__(self, message, code=401):
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
def encrypt_password(password_plain: str) -> str:
|
||||
try:
|
||||
from api.utils.crypt import crypt
|
||||
except Exception as exc:
|
||||
raise AuthException(
|
||||
"Password encryption unavailable; install pycryptodomex (uv sync --python 3.12 --group test)."
|
||||
) from exc
|
||||
return crypt(password_plain)
|
||||
|
||||
|
||||
def register_user(client: HttpClient, email: str, nickname: str, password: str) -> None:
|
||||
password_enc = encrypt_password(password)
|
||||
payload = {"email": email, "nickname": nickname, "password": password_enc}
|
||||
res = client.request_json("POST", "/user/register", use_api_base=False, auth_kind=None, json_body=payload)
|
||||
if res.get("code") == 0:
|
||||
return
|
||||
msg = res.get("message", "")
|
||||
if "has already registered" in msg:
|
||||
return
|
||||
raise AuthException(f"Register failed: {msg}")
|
||||
|
||||
|
||||
def login_user(client: HttpClient, server_type: str, email: str, password: str) -> str:
|
||||
password_enc = encrypt_password(password)
|
||||
payload = {"email": email, "password": password_enc}
|
||||
if server_type == "admin":
|
||||
response = client.request("POST", "/admin/login", use_api_base=True, auth_kind=None, json_body=payload)
|
||||
else:
|
||||
response = client.request("POST", "/user/login", use_api_base=False, auth_kind=None, json_body=payload)
|
||||
try:
|
||||
res = response.json()
|
||||
except Exception as exc:
|
||||
raise AuthException(f"Login failed: invalid JSON response ({exc})") from exc
|
||||
if res.get("code") != 0:
|
||||
raise AuthException(f"Login failed: {res.get('message')}")
|
||||
token = response.headers.get("Authorization")
|
||||
if not token:
|
||||
raise AuthException("Login failed: missing Authorization header")
|
||||
return token
|
||||
@ -14,10 +14,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import time
|
||||
start_ts = time.time()
|
||||
|
||||
import os
|
||||
import signal
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
import faulthandler
|
||||
@ -66,7 +68,7 @@ if __name__ == '__main__':
|
||||
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
|
||||
|
||||
try:
|
||||
logging.info("RAGFlow Admin service start...")
|
||||
logging.info(f"RAGFlow admin is ready after {time.time() - start_ts}s initialization.")
|
||||
run_simple(
|
||||
hostname="0.0.0.0",
|
||||
port=9381,
|
||||
|
||||
@ -15,29 +15,33 @@
|
||||
#
|
||||
|
||||
import secrets
|
||||
from typing import Any
|
||||
|
||||
from flask import Blueprint, request
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from datetime import datetime
|
||||
from flask import Blueprint, Response, request
|
||||
from flask_login import current_user, login_required, logout_user
|
||||
|
||||
from auth import login_verify, login_admin, check_admin_auth
|
||||
from responses import success_response, error_response
|
||||
from services import UserMgr, ServiceMgr, UserServiceMgr
|
||||
from services import UserMgr, ServiceMgr, UserServiceMgr, SettingsMgr, ConfigMgr, EnvironmentsMgr
|
||||
from roles import RoleMgr
|
||||
from api.common.exceptions import AdminException
|
||||
from common.versions import get_ragflow_version
|
||||
from api.utils.api_utils import generate_confirmation_token
|
||||
|
||||
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||
admin_bp = Blueprint("admin", __name__, url_prefix="/api/v1/admin")
|
||||
|
||||
|
||||
@admin_bp.route('/ping', methods=['GET'])
|
||||
@admin_bp.route("/ping", methods=["GET"])
|
||||
def ping():
|
||||
return success_response('PONG')
|
||||
return success_response("PONG")
|
||||
|
||||
|
||||
@admin_bp.route('/login', methods=['POST'])
|
||||
@admin_bp.route("/login", methods=["POST"])
|
||||
def login():
|
||||
if not request.json:
|
||||
return error_response('Authorize admin failed.' ,400)
|
||||
return error_response("Authorize admin failed.", 400)
|
||||
try:
|
||||
email = request.json.get("email", "")
|
||||
password = request.json.get("password", "")
|
||||
@ -46,7 +50,7 @@ def login():
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/logout', methods=['GET'])
|
||||
@admin_bp.route("/logout", methods=["GET"])
|
||||
@login_required
|
||||
def logout():
|
||||
try:
|
||||
@ -58,7 +62,7 @@ def logout():
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/auth', methods=['GET'])
|
||||
@admin_bp.route("/auth", methods=["GET"])
|
||||
@login_verify
|
||||
def auth_admin():
|
||||
try:
|
||||
@ -67,7 +71,7 @@ def auth_admin():
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users', methods=['GET'])
|
||||
@admin_bp.route("/users", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def list_users():
|
||||
@ -78,18 +82,18 @@ def list_users():
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users', methods=['POST'])
|
||||
@admin_bp.route("/users", methods=["POST"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def create_user():
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'username' not in data or 'password' not in data:
|
||||
if not data or "username" not in data or "password" not in data:
|
||||
return error_response("Username and password are required", 400)
|
||||
|
||||
username = data['username']
|
||||
password = data['password']
|
||||
role = data.get('role', 'user')
|
||||
username = data["username"]
|
||||
password = data["password"]
|
||||
role = data.get("role", "user")
|
||||
|
||||
res = UserMgr.create_user(username, password, role)
|
||||
if res["success"]:
|
||||
@ -105,7 +109,7 @@ def create_user():
|
||||
return error_response(str(e))
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
||||
@admin_bp.route("/users/<username>", methods=["DELETE"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def delete_user(username):
|
||||
@ -122,16 +126,16 @@ def delete_user(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/password', methods=['PUT'])
|
||||
@admin_bp.route("/users/<username>/password", methods=["PUT"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def change_password(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'new_password' not in data:
|
||||
if not data or "new_password" not in data:
|
||||
return error_response("New password is required", 400)
|
||||
|
||||
new_password = data['new_password']
|
||||
new_password = data["new_password"]
|
||||
msg = UserMgr.update_user_password(username, new_password)
|
||||
return success_response(None, msg)
|
||||
|
||||
@ -141,15 +145,15 @@ def change_password(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
||||
@admin_bp.route("/users/<username>/activate", methods=["PUT"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def alter_user_activate_status(username):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'activate_status' not in data:
|
||||
if not data or "activate_status" not in data:
|
||||
return error_response("Activation status is required", 400)
|
||||
activate_status = data['activate_status']
|
||||
activate_status = data["activate_status"]
|
||||
msg = UserMgr.update_user_activate_status(username, activate_status)
|
||||
return success_response(None, msg)
|
||||
except AdminException as e:
|
||||
@ -158,7 +162,39 @@ def alter_user_activate_status(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>', methods=['GET'])
|
||||
@admin_bp.route("/users/<username>/admin", methods=["PUT"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def grant_admin(username):
|
||||
try:
|
||||
if current_user.email == username:
|
||||
return error_response(f"can't grant current user: {username}", 409)
|
||||
msg = UserMgr.grant_admin(username)
|
||||
return success_response(None, msg)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/users/<username>/admin", methods=["DELETE"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def revoke_admin(username):
|
||||
try:
|
||||
if current_user.email == username:
|
||||
return error_response(f"can't grant current user: {username}", 409)
|
||||
msg = UserMgr.revoke_admin(username)
|
||||
return success_response(None, msg)
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/users/<username>", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_user_details(username):
|
||||
@ -172,7 +208,7 @@ def get_user_details(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
|
||||
@admin_bp.route("/users/<username>/datasets", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_user_datasets(username):
|
||||
@ -186,7 +222,7 @@ def get_user_datasets(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<username>/agents', methods=['GET'])
|
||||
@admin_bp.route("/users/<username>/agents", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_user_agents(username):
|
||||
@ -200,7 +236,7 @@ def get_user_agents(username):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services', methods=['GET'])
|
||||
@admin_bp.route("/services", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_services():
|
||||
@ -211,7 +247,7 @@ def get_services():
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/service_types/<service_type>', methods=['GET'])
|
||||
@admin_bp.route("/service_types/<service_type>", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_services_by_type(service_type_str):
|
||||
@ -222,7 +258,7 @@ def get_services_by_type(service_type_str):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services/<service_id>', methods=['GET'])
|
||||
@admin_bp.route("/services/<service_id>", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_service(service_id):
|
||||
@ -233,7 +269,7 @@ def get_service(service_id):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services/<service_id>', methods=['DELETE'])
|
||||
@admin_bp.route("/services/<service_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def shutdown_service(service_id):
|
||||
@ -244,7 +280,7 @@ def shutdown_service(service_id):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/services/<service_id>', methods=['PUT'])
|
||||
@admin_bp.route("/services/<service_id>", methods=["PUT"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def restart_service(service_id):
|
||||
@ -255,38 +291,38 @@ def restart_service(service_id):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/roles', methods=['POST'])
|
||||
@admin_bp.route("/roles", methods=["POST"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def create_role():
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'role_name' not in data:
|
||||
if not data or "role_name" not in data:
|
||||
return error_response("Role name is required", 400)
|
||||
role_name: str = data['role_name']
|
||||
description: str = data['description']
|
||||
role_name: str = data["role_name"]
|
||||
description: str = data["description"]
|
||||
res = RoleMgr.create_role(role_name, description)
|
||||
return success_response(res)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/roles/<role_name>', methods=['PUT'])
|
||||
@admin_bp.route("/roles/<role_name>", methods=["PUT"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def update_role(role_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'description' not in data:
|
||||
if not data or "description" not in data:
|
||||
return error_response("Role description is required", 400)
|
||||
description: str = data['description']
|
||||
description: str = data["description"]
|
||||
res = RoleMgr.update_role_description(role_name, description)
|
||||
return success_response(res)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/roles/<role_name>', methods=['DELETE'])
|
||||
@admin_bp.route("/roles/<role_name>", methods=["DELETE"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def delete_role(role_name: str):
|
||||
@ -297,7 +333,7 @@ def delete_role(role_name: str):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/roles', methods=['GET'])
|
||||
@admin_bp.route("/roles", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def list_roles():
|
||||
@ -308,7 +344,7 @@ def list_roles():
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/roles/<role_name>/permission', methods=['GET'])
|
||||
@admin_bp.route("/roles/<role_name>/permission", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_role_permission(role_name: str):
|
||||
@ -319,54 +355,54 @@ def get_role_permission(role_name: str):
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/roles/<role_name>/permission', methods=['POST'])
|
||||
@admin_bp.route("/roles/<role_name>/permission", methods=["POST"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def grant_role_permission(role_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'actions' not in data or 'resource' not in data:
|
||||
if not data or "actions" not in data or "resource" not in data:
|
||||
return error_response("Permission is required", 400)
|
||||
actions: list = data['actions']
|
||||
resource: str = data['resource']
|
||||
actions: list = data["actions"]
|
||||
resource: str = data["resource"]
|
||||
res = RoleMgr.grant_role_permission(role_name, actions, resource)
|
||||
return success_response(res)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/roles/<role_name>/permission', methods=['DELETE'])
|
||||
@admin_bp.route("/roles/<role_name>/permission", methods=["DELETE"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def revoke_role_permission(role_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'actions' not in data or 'resource' not in data:
|
||||
if not data or "actions" not in data or "resource" not in data:
|
||||
return error_response("Permission is required", 400)
|
||||
actions: list = data['actions']
|
||||
resource: str = data['resource']
|
||||
actions: list = data["actions"]
|
||||
resource: str = data["resource"]
|
||||
res = RoleMgr.revoke_role_permission(role_name, actions, resource)
|
||||
return success_response(res)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<user_name>/role', methods=['PUT'])
|
||||
@admin_bp.route("/users/<user_name>/role", methods=["PUT"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def update_user_role(user_name: str):
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data or 'role_name' not in data:
|
||||
if not data or "role_name" not in data:
|
||||
return error_response("Role name is required", 400)
|
||||
role_name: str = data['role_name']
|
||||
role_name: str = data["role_name"]
|
||||
res = RoleMgr.update_user_role(user_name, role_name)
|
||||
return success_response(res)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route('/users/<user_name>/permission', methods=['GET'])
|
||||
@admin_bp.route("/users/<user_name>/permission", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_user_permission(user_name: str):
|
||||
@ -376,7 +412,140 @@ def get_user_permission(user_name: str):
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
@admin_bp.route('/version', methods=['GET'])
|
||||
|
||||
@admin_bp.route("/variables", methods=["PUT"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def set_variable():
|
||||
try:
|
||||
data = request.get_json()
|
||||
if not data and "var_name" not in data:
|
||||
return error_response("Var name is required", 400)
|
||||
|
||||
if "var_value" not in data:
|
||||
return error_response("Var value is required", 400)
|
||||
var_name: str = data["var_name"]
|
||||
var_value: str = data["var_value"]
|
||||
|
||||
SettingsMgr.update_by_name(var_name, var_value)
|
||||
return success_response(None, "Set variable successfully")
|
||||
except AdminException as e:
|
||||
return error_response(str(e), 400)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/variables", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_variable():
|
||||
try:
|
||||
if request.content_length is None or request.content_length == 0:
|
||||
# list variables
|
||||
res = list(SettingsMgr.get_all())
|
||||
return success_response(res)
|
||||
|
||||
# get var
|
||||
data = request.get_json()
|
||||
if not data and "var_name" not in data:
|
||||
return error_response("Var name is required", 400)
|
||||
var_name: str = data["var_name"]
|
||||
res = SettingsMgr.get_by_name(var_name)
|
||||
return success_response(res)
|
||||
except AdminException as e:
|
||||
return error_response(str(e), 400)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/configs", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_config():
|
||||
try:
|
||||
res = list(ConfigMgr.get_all())
|
||||
return success_response(res)
|
||||
except AdminException as e:
|
||||
return error_response(str(e), 400)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/environments", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_environments():
|
||||
try:
|
||||
res = list(EnvironmentsMgr.get_all())
|
||||
return success_response(res)
|
||||
except AdminException as e:
|
||||
return error_response(str(e), 400)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/users/<username>/keys", methods=["POST"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def generate_user_api_key(username: str) -> tuple[Response, int]:
|
||||
try:
|
||||
user_details: list[dict[str, Any]] = UserMgr.get_user_details(username)
|
||||
if not user_details:
|
||||
return error_response("User not found!", 404)
|
||||
tenants: list[dict[str, Any]] = UserServiceMgr.get_user_tenants(username)
|
||||
if not tenants:
|
||||
return error_response("Tenant not found!", 404)
|
||||
tenant_id: str = tenants[0]["tenant_id"]
|
||||
key: str = generate_confirmation_token()
|
||||
obj: dict[str, Any] = {
|
||||
"tenant_id": tenant_id,
|
||||
"token": key,
|
||||
"beta": generate_confirmation_token().replace("ragflow-", "")[:32],
|
||||
"create_time": current_timestamp(),
|
||||
"create_date": datetime_format(datetime.now()),
|
||||
"update_time": None,
|
||||
"update_date": None,
|
||||
}
|
||||
|
||||
if not UserMgr.save_api_key(obj):
|
||||
return error_response("Failed to generate API key!", 500)
|
||||
return success_response(obj, "API key generated successfully")
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/users/<username>/keys", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def get_user_api_keys(username: str) -> tuple[Response, int]:
|
||||
try:
|
||||
api_keys: list[dict[str, Any]] = UserMgr.get_user_api_key(username)
|
||||
return success_response(api_keys, "Get user API keys")
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/users/<username>/keys/<key>", methods=["DELETE"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def delete_user_api_key(username: str, key: str) -> tuple[Response, int]:
|
||||
try:
|
||||
deleted = UserMgr.delete_api_key(username, key)
|
||||
if deleted:
|
||||
return success_response(None, "API key deleted successfully")
|
||||
else:
|
||||
return error_response("API key not found or could not be deleted", 404)
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
except Exception as e:
|
||||
return error_response(str(e), 500)
|
||||
|
||||
|
||||
@admin_bp.route("/version", methods=["GET"])
|
||||
@login_required
|
||||
@check_admin_auth
|
||||
def show_version():
|
||||
|
||||
@ -17,13 +17,18 @@
|
||||
import os
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from werkzeug.security import check_password_hash
|
||||
from common.constants import ActiveEnum
|
||||
from api.db.services import UserService
|
||||
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.system_settings_service import SystemSettingsService
|
||||
from api.db.services.api_service import APITokenService
|
||||
from api.db.db_models import APIToken
|
||||
from api.utils.crypt import decrypt
|
||||
from api.utils import health_utils
|
||||
|
||||
@ -37,13 +42,15 @@ class UserMgr:
|
||||
users = UserService.get_all_users()
|
||||
result = []
|
||||
for user in users:
|
||||
result.append({
|
||||
'email': user.email,
|
||||
'nickname': user.nickname,
|
||||
'create_date': user.create_date,
|
||||
'is_active': user.is_active,
|
||||
'is_superuser': user.is_superuser,
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"email": user.email,
|
||||
"nickname": user.nickname,
|
||||
"create_date": user.create_date,
|
||||
"is_active": user.is_active,
|
||||
"is_superuser": user.is_superuser,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@ -52,19 +59,21 @@ class UserMgr:
|
||||
users = UserService.query_user_by_email(username)
|
||||
result = []
|
||||
for user in users:
|
||||
result.append({
|
||||
'avatar': user.avatar,
|
||||
'email': user.email,
|
||||
'language': user.language,
|
||||
'last_login_time': user.last_login_time,
|
||||
'is_active': user.is_active,
|
||||
'is_anonymous': user.is_anonymous,
|
||||
'login_channel': user.login_channel,
|
||||
'status': user.status,
|
||||
'is_superuser': user.is_superuser,
|
||||
'create_date': user.create_date,
|
||||
'update_date': user.update_date
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"avatar": user.avatar,
|
||||
"email": user.email,
|
||||
"language": user.language,
|
||||
"last_login_time": user.last_login_time,
|
||||
"is_active": user.is_active,
|
||||
"is_anonymous": user.is_anonymous,
|
||||
"login_channel": user.login_channel,
|
||||
"status": user.status,
|
||||
"is_superuser": user.is_superuser,
|
||||
"create_date": user.create_date,
|
||||
"update_date": user.update_date,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@ -126,8 +135,8 @@ class UserMgr:
|
||||
# format activate_status before handle
|
||||
_activate_status = activate_status.lower()
|
||||
target_status = {
|
||||
'on': ActiveEnum.ACTIVE.value,
|
||||
'off': ActiveEnum.INACTIVE.value,
|
||||
"on": ActiveEnum.ACTIVE.value,
|
||||
"off": ActiveEnum.INACTIVE.value,
|
||||
}.get(_activate_status)
|
||||
if not target_status:
|
||||
raise AdminException(f"Invalid activate_status: {activate_status}")
|
||||
@ -137,9 +146,84 @@ class UserMgr:
|
||||
UserService.update_user(usr.id, {"is_active": target_status})
|
||||
return f"Turn {_activate_status} user activate status successfully!"
|
||||
|
||||
@staticmethod
|
||||
def get_user_api_key(username: str) -> list[dict[str, Any]]:
|
||||
# use email to find user. check exist and unique.
|
||||
user_list: list[Any] = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"More than one user with username '{username}' found!")
|
||||
|
||||
usr: Any = user_list[0]
|
||||
# tenant_id is typically the same as user_id for the owner tenant
|
||||
tenant_id: str = usr.id
|
||||
|
||||
# Query all API keys for this tenant
|
||||
api_keys: Any = APITokenService.query(tenant_id=tenant_id)
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
for key in api_keys:
|
||||
result.append(key.to_dict())
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def save_api_key(api_key: dict[str, Any]) -> bool:
|
||||
return APITokenService.save(**api_key)
|
||||
|
||||
@staticmethod
|
||||
def delete_api_key(username: str, key: str) -> bool:
|
||||
# use email to find user. check exist and unique.
|
||||
user_list: list[Any] = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
|
||||
usr: Any = user_list[0]
|
||||
# tenant_id is typically the same as user_id for the owner tenant
|
||||
tenant_id: str = usr.id
|
||||
|
||||
# Delete the API key
|
||||
deleted_count: int = APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == key])
|
||||
return deleted_count > 0
|
||||
|
||||
@staticmethod
|
||||
def grant_admin(username: str):
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
|
||||
# check activate status different from new
|
||||
usr = user_list[0]
|
||||
if usr.is_superuser:
|
||||
return f"{usr} is already superuser!"
|
||||
# update is_active
|
||||
UserService.update_user(usr.id, {"is_superuser": True})
|
||||
return "Grant successfully!"
|
||||
|
||||
@staticmethod
|
||||
def revoke_admin(username: str):
|
||||
# use email to find user. check exist and unique.
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
elif len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
# check activate status different from new
|
||||
usr = user_list[0]
|
||||
if not usr.is_superuser:
|
||||
return f"{usr} isn't superuser, yet!"
|
||||
# update is_active
|
||||
UserService.update_user(usr.id, {"is_superuser": False})
|
||||
return "Revoke successfully!"
|
||||
|
||||
|
||||
class UserServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_user_datasets(username):
|
||||
# use email to find user.
|
||||
@ -169,39 +253,43 @@ class UserServiceMgr:
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted agents and owned agents
|
||||
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||
return [{
|
||||
'title': r['title'],
|
||||
'permission': r['permission'],
|
||||
'canvas_category': r['canvas_category'].split('_')[0],
|
||||
'avatar': r['avatar']
|
||||
} for r in res]
|
||||
return [{"title": r["title"], "permission": r["permission"], "canvas_category": r["canvas_category"].split("_")[0], "avatar": r["avatar"]} for r in res]
|
||||
|
||||
@staticmethod
|
||||
def get_user_tenants(email: str) -> list[dict[str, Any]]:
|
||||
users: list[Any] = UserService.query_user_by_email(email)
|
||||
if not users:
|
||||
raise UserNotFoundError(email)
|
||||
user: Any = users[0]
|
||||
|
||||
tenants: list[dict[str, Any]] = UserTenantService.get_tenants_by_user_id(user.id)
|
||||
return tenants
|
||||
|
||||
|
||||
class ServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_all_services():
|
||||
doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
|
||||
doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
|
||||
result = []
|
||||
configs = SERVICE_CONFIGS.configs
|
||||
for service_id, config in enumerate(configs):
|
||||
config_dict = config.to_dict()
|
||||
if config_dict['service_type'] == 'retrieval':
|
||||
if config_dict['extra']['retrieval_type'] != doc_engine:
|
||||
if config_dict["service_type"] == "retrieval":
|
||||
if config_dict["extra"]["retrieval_type"] != doc_engine:
|
||||
continue
|
||||
try:
|
||||
service_detail = ServiceMgr.get_service_details(service_id)
|
||||
if "status" in service_detail:
|
||||
config_dict['status'] = service_detail['status']
|
||||
config_dict["status"] = service_detail["status"]
|
||||
else:
|
||||
config_dict['status'] = 'timeout'
|
||||
config_dict["status"] = "timeout"
|
||||
except Exception as e:
|
||||
logging.warning(f"Can't get service details, error: {e}")
|
||||
config_dict['status'] = 'timeout'
|
||||
if not config_dict['host']:
|
||||
config_dict['host'] = '-'
|
||||
if not config_dict['port']:
|
||||
config_dict['port'] = '-'
|
||||
config_dict["status"] = "timeout"
|
||||
if not config_dict["host"]:
|
||||
config_dict["host"] = "-"
|
||||
if not config_dict["port"]:
|
||||
config_dict["port"] = "-"
|
||||
result.append(config_dict)
|
||||
return result
|
||||
|
||||
@ -217,11 +305,18 @@ class ServiceMgr:
|
||||
raise AdminException(f"invalid service_index: {service_idx}")
|
||||
|
||||
service_config = configs[service_idx]
|
||||
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
|
||||
|
||||
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
|
||||
# exclude retrieval service if retrieval_type is not matched
|
||||
doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
|
||||
if service_config.service_type == "retrieval":
|
||||
if service_config.retrieval_type != doc_engine:
|
||||
raise AdminException(f"invalid service_index: {service_idx}")
|
||||
|
||||
service_info = {"name": service_config.name, "detail_func_name": service_config.detail_func_name}
|
||||
|
||||
detail_func = getattr(health_utils, service_info.get("detail_func_name"))
|
||||
res = detail_func()
|
||||
res.update({'service_name': service_info.get('name')})
|
||||
res.update({"service_name": service_info.get("name")})
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
@ -231,3 +326,84 @@ class ServiceMgr:
|
||||
@staticmethod
|
||||
def restart_service(service_id: int):
|
||||
raise AdminException("restart_service: not implemented")
|
||||
|
||||
|
||||
class SettingsMgr:
|
||||
@staticmethod
|
||||
def get_all():
|
||||
settings = SystemSettingsService.get_all()
|
||||
result = []
|
||||
for setting in settings:
|
||||
result.append(
|
||||
{
|
||||
"name": setting.name,
|
||||
"source": setting.source,
|
||||
"data_type": setting.data_type,
|
||||
"value": setting.value,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_by_name(name: str):
|
||||
settings = SystemSettingsService.get_by_name(name)
|
||||
if len(settings) == 0:
|
||||
raise AdminException(f"Can't get setting: {name}")
|
||||
result = []
|
||||
for setting in settings:
|
||||
result.append(
|
||||
{
|
||||
"name": setting.name,
|
||||
"source": setting.source,
|
||||
"data_type": setting.data_type,
|
||||
"value": setting.value,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def update_by_name(name: str, value: str):
|
||||
settings = SystemSettingsService.get_by_name(name)
|
||||
if len(settings) == 1:
|
||||
setting = settings[0]
|
||||
setting.value = value
|
||||
setting_dict = setting.to_dict()
|
||||
SystemSettingsService.update_by_name(name, setting_dict)
|
||||
elif len(settings) > 1:
|
||||
raise AdminException(f"Can't update more than 1 setting: {name}")
|
||||
else:
|
||||
raise AdminException(f"No setting: {name}")
|
||||
|
||||
|
||||
class ConfigMgr:
|
||||
@staticmethod
|
||||
def get_all():
|
||||
result = []
|
||||
configs = SERVICE_CONFIGS.configs
|
||||
for config in configs:
|
||||
config_dict = config.to_dict()
|
||||
result.append(config_dict)
|
||||
return result
|
||||
|
||||
|
||||
class EnvironmentsMgr:
|
||||
@staticmethod
|
||||
def get_all():
|
||||
result = []
|
||||
|
||||
env_kv = {"env": "DOC_ENGINE", "value": os.getenv("DOC_ENGINE")}
|
||||
result.append(env_kv)
|
||||
|
||||
env_kv = {"env": "DEFAULT_SUPERUSER_EMAIL", "value": os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")}
|
||||
result.append(env_kv)
|
||||
|
||||
env_kv = {"env": "DB_TYPE", "value": os.getenv("DB_TYPE", "mysql")}
|
||||
result.append(env_kv)
|
||||
|
||||
env_kv = {"env": "DEVICE", "value": os.getenv("DEVICE", "cpu")}
|
||||
result.append(env_kv)
|
||||
|
||||
env_kv = {"env": "STORAGE_IMPL", "value": os.getenv("STORAGE_IMPL", "MINIO")}
|
||||
result.append(env_kv)
|
||||
|
||||
return result
|
||||
|
||||
@ -27,6 +27,10 @@ import pandas as pd
|
||||
from agent import settings
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
|
||||
_DEPRECATED_PARAMS = "_deprecated_params"
|
||||
_USER_FEEDED_PARAMS = "_user_feeded_params"
|
||||
@ -379,6 +383,7 @@ class ComponentBase(ABC):
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
from agent.canvas import Graph # Local import to avoid cyclic dependency
|
||||
|
||||
assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
|
||||
self._canvas = canvas
|
||||
self._id = id
|
||||
@ -430,7 +435,7 @@ class ComponentBase(ABC):
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
await self._invoke(**kwargs)
|
||||
else:
|
||||
await asyncio.to_thread(self._invoke, **kwargs)
|
||||
await thread_pool_exec(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
|
||||
@ -97,6 +97,13 @@ Here's description of each category:
|
||||
class Categorize(LLM, ABC):
|
||||
component_name = "Categorize"
|
||||
|
||||
def get_input_elements(self) -> dict[str, dict]:
|
||||
query_key = self._param.query or "sys.query"
|
||||
elements = self.get_input_elements_from_text(f"{{{query_key}}}")
|
||||
if not elements:
|
||||
logging.warning(f"[Categorize] input element not detected for query key: {query_key}")
|
||||
return elements
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Categorize processing"):
|
||||
@ -105,12 +112,15 @@ class Categorize(LLM, ABC):
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if not msg:
|
||||
msg = [{"role": "user", "content": ""}]
|
||||
if kwargs.get("sys.query"):
|
||||
msg[-1]["content"] = kwargs["sys.query"]
|
||||
self.set_input_value("sys.query", kwargs["sys.query"])
|
||||
query_key = self._param.query or "sys.query"
|
||||
if query_key in kwargs:
|
||||
query_value = kwargs[query_key]
|
||||
else:
|
||||
msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
|
||||
self.set_input_value(self._param.query, msg[-1]["content"])
|
||||
query_value = self._canvas.get_variable_value(query_key)
|
||||
if query_value is None:
|
||||
query_value = ""
|
||||
msg[-1]["content"] = query_value
|
||||
self.set_input_value(query_key, msg[-1]["content"])
|
||||
self._param.update_prompt()
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
|
||||
|
||||
@ -27,6 +27,10 @@ from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
|
||||
from timeit import default_timer as timer
|
||||
|
||||
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
class ToolParameter(TypedDict):
|
||||
type: str
|
||||
description: str
|
||||
@ -56,12 +60,12 @@ class LLMToolPluginCallSession(ToolCallSession):
|
||||
st = timer()
|
||||
tool_obj = self.tools_map[name]
|
||||
if isinstance(tool_obj, MCPToolCallSession):
|
||||
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
|
||||
resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
|
||||
else:
|
||||
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||
resp = await tool_obj.invoke_async(**arguments)
|
||||
else:
|
||||
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
|
||||
resp = await thread_pool_exec(tool_obj.invoke, **arguments)
|
||||
|
||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||
return resp
|
||||
@ -122,6 +126,7 @@ class ToolParamBase(ComponentParamBase):
|
||||
class ToolBase(ComponentBase):
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
from agent.canvas import Canvas # Local import to avoid cyclic dependency
|
||||
|
||||
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
|
||||
self._canvas = canvas
|
||||
self._id = id
|
||||
@ -164,7 +169,7 @@ class ToolBase(ComponentBase):
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
res = await self._invoke(**kwargs)
|
||||
else:
|
||||
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||
res = await thread_pool_exec(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||
logging.exception(e)
|
||||
|
||||
@ -86,6 +86,12 @@ class ExeSQL(ToolBase, ABC):
|
||||
|
||||
def convert_decimals(obj):
|
||||
from decimal import Decimal
|
||||
import math
|
||||
if isinstance(obj, float):
|
||||
# Handle NaN and Infinity which are not valid JSON values
|
||||
if math.isnan(obj) or math.isinf(obj):
|
||||
return None
|
||||
return obj
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj) # 或 str(obj)
|
||||
elif isinstance(obj, dict):
|
||||
|
||||
@ -174,7 +174,7 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
if kbs:
|
||||
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
||||
kbinfos = settings.retriever.retrieval(
|
||||
kbinfos = await settings.retriever.retrieval(
|
||||
query,
|
||||
embd_mdl,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
@ -193,7 +193,7 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
if self._param.toc_enhance:
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
|
||||
cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
|
||||
chat_mdl, self._param.top_n)
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
|
||||
@ -1 +0,0 @@
|
||||
from .deep_research import DeepResearcher as DeepResearcher
|
||||
@ -1,238 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
from functools import partial
|
||||
from agentic_reasoning.prompts import BEGIN_SEARCH_QUERY, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT, MAX_SEARCH_LIMIT, \
|
||||
END_SEARCH_QUERY, REASON_PROMPT, RELEVANT_EXTRACTION_PROMPT
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from rag.nlp import extract_between
|
||||
from rag.prompts import kb_prompt
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
|
||||
class DeepResearcher:
|
||||
def __init__(self,
|
||||
chat_mdl: LLMBundle,
|
||||
prompt_config: dict,
|
||||
kb_retrieve: partial = None,
|
||||
kg_retrieve: partial = None
|
||||
):
|
||||
self.chat_mdl = chat_mdl
|
||||
self.prompt_config = prompt_config
|
||||
self._kb_retrieve = kb_retrieve
|
||||
self._kg_retrieve = kg_retrieve
|
||||
|
||||
def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
|
||||
"""General Tag Removal Method"""
|
||||
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
|
||||
return re.sub(pattern, "", text)
|
||||
|
||||
@staticmethod
|
||||
def _remove_query_tags(text: str) -> str:
|
||||
"""Remove Query Tags"""
|
||||
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
|
||||
|
||||
@staticmethod
|
||||
def _remove_result_tags(text: str) -> str:
|
||||
"""Remove Result Tags"""
|
||||
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
|
||||
|
||||
async def _generate_reasoning(self, msg_history):
|
||||
"""Generate reasoning steps"""
|
||||
query_think = ""
|
||||
if msg_history[-1]["role"] != "user":
|
||||
msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
|
||||
else:
|
||||
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
|
||||
|
||||
async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
query_think = ans
|
||||
yield query_think
|
||||
query_think = ""
|
||||
yield query_think
|
||||
|
||||
def _extract_search_queries(self, query_think, question, step_index):
|
||||
"""Extract search queries from thinking"""
|
||||
queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
|
||||
if not queries and step_index == 0:
|
||||
# If this is the first step and no queries are found, use the original question as the query
|
||||
queries = [question]
|
||||
return queries
|
||||
|
||||
def _truncate_previous_reasoning(self, all_reasoning_steps):
|
||||
"""Truncate previous reasoning steps to maintain a reasonable length"""
|
||||
truncated_prev_reasoning = ""
|
||||
for i, step in enumerate(all_reasoning_steps):
|
||||
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
|
||||
|
||||
prev_steps = truncated_prev_reasoning.split('\n\n')
|
||||
if len(prev_steps) <= 5:
|
||||
truncated_prev_reasoning = '\n\n'.join(prev_steps)
|
||||
else:
|
||||
truncated_prev_reasoning = ''
|
||||
for i, step in enumerate(prev_steps):
|
||||
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
|
||||
truncated_prev_reasoning += step + '\n\n'
|
||||
else:
|
||||
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
|
||||
truncated_prev_reasoning += '...\n\n'
|
||||
|
||||
return truncated_prev_reasoning.strip('\n')
|
||||
|
||||
def _retrieve_information(self, search_query):
|
||||
"""Retrieve information from different sources"""
|
||||
# 1. Knowledge base retrieval
|
||||
kbinfos = []
|
||||
try:
|
||||
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
|
||||
except Exception as e:
|
||||
logging.error(f"Knowledge base retrieval error: {e}")
|
||||
|
||||
# 2. Web retrieval (if Tavily API is configured)
|
||||
try:
|
||||
if self.prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(self.prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(search_query)
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
except Exception as e:
|
||||
logging.error(f"Web retrieval error: {e}")
|
||||
|
||||
# 3. Knowledge graph retrieval (if configured)
|
||||
try:
|
||||
if self.prompt_config.get("use_kg") and self._kg_retrieve:
|
||||
ck = self._kg_retrieve(question=search_query)
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
except Exception as e:
|
||||
logging.error(f"Knowledge graph retrieval error: {e}")
|
||||
|
||||
return kbinfos
|
||||
|
||||
def _update_chunk_info(self, chunk_info, kbinfos):
|
||||
"""Update chunk information for citations"""
|
||||
if not chunk_info["chunks"]:
|
||||
# If this is the first retrieval, use the retrieval results directly
|
||||
for k in chunk_info.keys():
|
||||
chunk_info[k] = kbinfos[k]
|
||||
else:
|
||||
# Merge newly retrieved information, avoiding duplicates
|
||||
cids = [c["chunk_id"] for c in chunk_info["chunks"]]
|
||||
for c in kbinfos["chunks"]:
|
||||
if c["chunk_id"] not in cids:
|
||||
chunk_info["chunks"].append(c)
|
||||
|
||||
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
|
||||
for d in kbinfos["doc_aggs"]:
|
||||
if d["doc_id"] not in dids:
|
||||
chunk_info["doc_aggs"].append(d)
|
||||
|
||||
async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
|
||||
"""Extract and summarize relevant information"""
|
||||
summary_think = ""
|
||||
async for ans in self.chat_mdl.async_chat_streamly(
|
||||
RELEVANT_EXTRACTION_PROMPT.format(
|
||||
prev_reasoning=truncated_prev_reasoning,
|
||||
search_query=search_query,
|
||||
document="\n".join(kb_prompt(kbinfos, 4096))
|
||||
),
|
||||
[{"role": "user",
|
||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
||||
{"temperature": 0.7}):
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
summary_think = ans
|
||||
yield summary_think
|
||||
summary_think = ""
|
||||
|
||||
yield summary_think
|
||||
|
||||
async def thinking(self, chunk_info: dict, question: str):
|
||||
executed_search_queries = []
|
||||
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
|
||||
all_reasoning_steps = []
|
||||
think = "<think>"
|
||||
|
||||
for step_index in range(MAX_SEARCH_LIMIT + 1):
|
||||
# Check if the maximum search limit has been reached
|
||||
if step_index == MAX_SEARCH_LIMIT - 1:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_history.append({"role": "assistant", "content": summary_think})
|
||||
break
|
||||
|
||||
# Step 1: Generate reasoning
|
||||
query_think = ""
|
||||
async for ans in self._generate_reasoning(msg_history):
|
||||
query_think = ans
|
||||
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
think += self._remove_query_tags(query_think)
|
||||
all_reasoning_steps.append(query_think)
|
||||
|
||||
# Step 2: Extract search queries
|
||||
queries = self._extract_search_queries(query_think, question, step_index)
|
||||
if not queries and step_index > 0:
|
||||
# If not the first step and no queries, end the search process
|
||||
break
|
||||
|
||||
# Process each search query
|
||||
for search_query in queries:
|
||||
logging.info(f"[THINK]Query: {step_index}. {search_query}")
|
||||
msg_history.append({"role": "assistant", "content": search_query})
|
||||
think += f"\n\n> {step_index + 1}. {search_query}\n\n"
|
||||
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
# Check if the query has already been executed
|
||||
if search_query in executed_search_queries:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_history.append({"role": "user", "content": summary_think})
|
||||
think += summary_think
|
||||
continue
|
||||
|
||||
executed_search_queries.append(search_query)
|
||||
|
||||
# Step 3: Truncate previous reasoning steps
|
||||
truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
|
||||
|
||||
# Step 4: Retrieve information
|
||||
kbinfos = self._retrieve_information(search_query)
|
||||
|
||||
# Step 5: Update chunk information
|
||||
self._update_chunk_info(chunk_info, kbinfos)
|
||||
|
||||
# Step 6: Extract relevant information
|
||||
think += "\n\n"
|
||||
summary_think = ""
|
||||
async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
|
||||
summary_think = ans
|
||||
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_history.append(
|
||||
{"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
|
||||
think += self._remove_result_tags(summary_think)
|
||||
logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
|
||||
|
||||
yield think + "</think>"
|
||||
@ -1,147 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
|
||||
END_SEARCH_QUERY = "<|end_search_query|>"
|
||||
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
|
||||
END_SEARCH_RESULT = "<|end_search_result|>"
|
||||
MAX_SEARCH_LIMIT = 6
|
||||
|
||||
REASON_PROMPT = f"""You are an advanced reasoning agent. Your goal is to answer the user's question by breaking it down into a series of verifiable steps.
|
||||
|
||||
You have access to a powerful search tool to find information.
|
||||
|
||||
**Your Task:**
|
||||
1. Analyze the user's question.
|
||||
2. If you need information, issue a search query to find a specific fact.
|
||||
3. Review the search results.
|
||||
4. Repeat the search process until you have all the facts needed to answer the question.
|
||||
5. Once you have gathered sufficient information, synthesize the facts and provide the final answer directly.
|
||||
|
||||
**Tool Usage:**
|
||||
- To search, you MUST write your query between the special tokens: {BEGIN_SEARCH_QUERY}your query{END_SEARCH_QUERY}.
|
||||
- The system will provide results between {BEGIN_SEARCH_RESULT}search results{END_SEARCH_RESULT}.
|
||||
- You have a maximum of {MAX_SEARCH_LIMIT} search attempts.
|
||||
|
||||
---
|
||||
**Example 1: Multi-hop Question**
|
||||
|
||||
**Question:** "Are both the directors of Jaws and Casino Royale from the same country?"
|
||||
|
||||
**Your Thought Process & Actions:**
|
||||
First, I need to identify the director of Jaws.
|
||||
{BEGIN_SEARCH_QUERY}who is the director of Jaws?{END_SEARCH_QUERY}
|
||||
[System returns search results]
|
||||
{BEGIN_SEARCH_RESULT}
|
||||
Jaws is a 1975 American thriller film directed by Steven Spielberg.
|
||||
{END_SEARCH_RESULT}
|
||||
Okay, the director of Jaws is Steven Spielberg. Now I need to find out his nationality.
|
||||
{BEGIN_SEARCH_QUERY}where is Steven Spielberg from?{END_SEARCH_QUERY}
|
||||
[System returns search results]
|
||||
{BEGIN_SEARCH_RESULT}
|
||||
Steven Allan Spielberg is an American filmmaker. Born in Cincinnati, Ohio...
|
||||
{END_SEARCH_RESULT}
|
||||
So, Steven Spielberg is from the USA. Next, I need to find the director of Casino Royale.
|
||||
{BEGIN_SEARCH_QUERY}who is the director of Casino Royale 2006?{END_SEARCH_QUERY}
|
||||
[System returns search results]
|
||||
{BEGIN_SEARCH_RESULT}
|
||||
Casino Royale is a 2006 spy film directed by Martin Campbell.
|
||||
{END_SEARCH_RESULT}
|
||||
The director of Casino Royale is Martin Campbell. Now I need his nationality.
|
||||
{BEGIN_SEARCH_QUERY}where is Martin Campbell from?{END_SEARCH_QUERY}
|
||||
[System returns search results]
|
||||
{BEGIN_SEARCH_RESULT}
|
||||
Martin Campbell (born 24 October 1943) is a New Zealand film and television director.
|
||||
{END_SEARCH_RESULT}
|
||||
I have all the information. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand. They are not from the same country.
|
||||
|
||||
Final Answer: No, the directors of Jaws and Casino Royale are not from the same country. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand.
|
||||
|
||||
---
|
||||
**Example 2: Simple Fact Retrieval**
|
||||
|
||||
**Question:** "When was the founder of craigslist born?"
|
||||
|
||||
**Your Thought Process & Actions:**
|
||||
First, I need to know who founded craigslist.
|
||||
{BEGIN_SEARCH_QUERY}who founded craigslist?{END_SEARCH_QUERY}
|
||||
[System returns search results]
|
||||
{BEGIN_SEARCH_RESULT}
|
||||
Craigslist was founded in 1995 by Craig Newmark.
|
||||
{END_SEARCH_RESULT}
|
||||
The founder is Craig Newmark. Now I need his birth date.
|
||||
{BEGIN_SEARCH_QUERY}when was Craig Newmark born?{END_SEARCH_QUERY}
|
||||
[System returns search results]
|
||||
{BEGIN_SEARCH_RESULT}
|
||||
Craig Newmark was born on December 6, 1952.
|
||||
{END_SEARCH_RESULT}
|
||||
I have found the answer.
|
||||
|
||||
Final Answer: The founder of craigslist, Craig Newmark, was born on December 6, 1952.
|
||||
|
||||
---
|
||||
**Important Rules:**
|
||||
- **One Fact at a Time:** Decompose the problem and issue one search query at a time to find a single, specific piece of information.
|
||||
- **Be Precise:** Formulate clear and precise search queries. If a search fails, rephrase it.
|
||||
- **Synthesize at the End:** Do not provide the final answer until you have completed all necessary searches.
|
||||
- **Language Consistency:** Your search queries should be in the same language as the user's question.
|
||||
|
||||
Now, begin your work. Please answer the following question by thinking step-by-step.
|
||||
"""
|
||||
|
||||
RELEVANT_EXTRACTION_PROMPT = """You are a highly efficient information extraction module. Your sole purpose is to extract the single most relevant piece of information from the provided `Searched Web Pages` that directly answers the `Current Search Query`.
|
||||
|
||||
**Your Task:**
|
||||
1. Read the `Current Search Query` to understand what specific information is needed.
|
||||
2. Scan the `Searched Web Pages` to find the answer to that query.
|
||||
3. Extract only the essential, factual information that answers the query. Be concise.
|
||||
|
||||
**Context (For Your Information Only):**
|
||||
The `Previous Reasoning Steps` are provided to give you context on the overall goal, but your primary focus MUST be on answering the `Current Search Query`. Do not use information from the previous steps in your output.
|
||||
|
||||
**Output Format:**
|
||||
Your response must follow one of two formats precisely.
|
||||
|
||||
1. **If a direct and relevant answer is found:**
|
||||
- Start your response immediately with `Final Information`.
|
||||
- Provide only the extracted fact(s). Do not add any extra conversational text.
|
||||
|
||||
*Example:*
|
||||
`Current Search Query`: Where is Martin Campbell from?
|
||||
`Searched Web Pages`: [Long article snippet about Martin Campbell's career, which includes the sentence "Martin Campbell (born 24 October 1943) is a New Zealand film and television director..."]
|
||||
|
||||
*Your Output:*
|
||||
Final Information
|
||||
Martin Campbell is a New Zealand film and television director.
|
||||
|
||||
2. **If no relevant answer that directly addresses the query is found in the web pages:**
|
||||
- Start your response immediately with `Final Information`.
|
||||
- Write the exact phrase: `No helpful information found.`
|
||||
|
||||
---
|
||||
**BEGIN TASK**
|
||||
|
||||
**Inputs:**
|
||||
|
||||
- **Previous Reasoning Steps:**
|
||||
{prev_reasoning}
|
||||
|
||||
- **Current Search Query:**
|
||||
{search_query}
|
||||
|
||||
- **Searched Web Pages:**
|
||||
{document}
|
||||
"""
|
||||
@ -16,21 +16,23 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from quart import Blueprint, Quart, request, g, current_app, session
|
||||
from flasgger import Swagger
|
||||
from quart import Blueprint, Quart, request, g, current_app, session, jsonify
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from quart_cors import cors
|
||||
from common.constants import StatusEnum
|
||||
from common.constants import StatusEnum, RetCode
|
||||
from api.db.db_models import close_connection, APIToken
|
||||
from api.db.services import UserService
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from quart_auth import Unauthorized
|
||||
from quart_auth import Unauthorized as QuartAuthUnauthorized
|
||||
from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
|
||||
from quart_schema import QuartSchema
|
||||
from common import settings
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.utils.api_utils import server_error_response, get_json_result
|
||||
from api.constants import API_VERSION
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
@ -38,40 +40,27 @@ settings.init_settings()
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
|
||||
|
||||
|
||||
def _unauthorized_message(error):
|
||||
if error is None:
|
||||
return UNAUTHORIZED_MESSAGE
|
||||
try:
|
||||
msg = repr(error)
|
||||
except Exception:
|
||||
return UNAUTHORIZED_MESSAGE
|
||||
if msg == UNAUTHORIZED_MESSAGE:
|
||||
return msg
|
||||
if "Unauthorized" in msg and "401" in msg:
|
||||
return msg
|
||||
return UNAUTHORIZED_MESSAGE
|
||||
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
|
||||
# Add this at the beginning of your file to configure Swagger UI
|
||||
swagger_config = {
|
||||
"headers": [],
|
||||
"specs": [
|
||||
{
|
||||
"endpoint": "apispec",
|
||||
"route": "/apispec.json",
|
||||
"rule_filter": lambda rule: True, # Include all endpoints
|
||||
"model_filter": lambda tag: True, # Include all models
|
||||
}
|
||||
],
|
||||
"static_url_path": "/flasgger_static",
|
||||
"swagger_ui": True,
|
||||
"specs_route": "/apidocs/",
|
||||
}
|
||||
|
||||
swagger = Swagger(
|
||||
app,
|
||||
config=swagger_config,
|
||||
template={
|
||||
"swagger": "2.0",
|
||||
"info": {
|
||||
"title": "RAGFlow API",
|
||||
"description": "",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"securityDefinitions": {
|
||||
"ApiKeyAuth": {"type": "apiKey", "name": "Authorization", "in": "header"}
|
||||
},
|
||||
},
|
||||
)
|
||||
# openapi supported
|
||||
QuartSchema(app)
|
||||
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
@ -125,18 +114,28 @@ def _load_user():
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if not user and len(authorization.split()) == 2:
|
||||
objs = APIToken.query(token=authorization.split()[1])
|
||||
if objs:
|
||||
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
except Exception as e_auth:
|
||||
logging.warning(f"load_user got exception {e_auth}")
|
||||
try:
|
||||
authorization = request.headers.get("Authorization")
|
||||
if len(authorization.split()) == 2:
|
||||
objs = APIToken.query(token=authorization.split()[1])
|
||||
if objs:
|
||||
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
except Exception as e_api_token:
|
||||
logging.warning(f"load_user got exception {e_api_token}")
|
||||
|
||||
|
||||
current_user = LocalProxy(_load_user)
|
||||
@ -164,10 +163,18 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if not current_user: # or not session.get("_user_id"):
|
||||
raise Unauthorized()
|
||||
else:
|
||||
return await current_app.ensure_async(func)(*args, **kwargs)
|
||||
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
|
||||
t_start = time.perf_counter() if timing_enabled else None
|
||||
user = current_user
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing login_required auth_ms=%.2f path=%s",
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
if not user: # or not session.get("_user_id"):
|
||||
raise QuartAuthUnauthorized()
|
||||
return await current_app.ensure_async(func)(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -277,14 +284,34 @@ client_urls_prefix = [
|
||||
|
||||
@app.errorhandler(404)
|
||||
async def not_found(error):
|
||||
error_msg: str = f"The requested URL {request.path} was not found"
|
||||
logging.error(error_msg)
|
||||
return {
|
||||
logging.error(f"The requested URL {request.path} was not found")
|
||||
message = f"Not Found: {request.path}"
|
||||
response = {
|
||||
"code": RetCode.NOT_FOUND,
|
||||
"message": message,
|
||||
"data": None,
|
||||
"error": "Not Found",
|
||||
"message": error_msg,
|
||||
}, 404
|
||||
}
|
||||
return jsonify(response), RetCode.NOT_FOUND
|
||||
|
||||
|
||||
@app.errorhandler(401)
|
||||
async def unauthorized(error):
|
||||
logging.warning("Unauthorized request")
|
||||
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
|
||||
|
||||
|
||||
@app.errorhandler(QuartAuthUnauthorized)
|
||||
async def unauthorized_quart_auth(error):
|
||||
logging.warning("Unauthorized request (quart_auth)")
|
||||
return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
|
||||
|
||||
|
||||
@app.errorhandler(WerkzeugUnauthorized)
|
||||
async def unauthorized_werkzeug(error):
|
||||
logging.warning("Unauthorized request (werkzeug)")
|
||||
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
@ -29,9 +28,14 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
|
||||
get_request_json
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import (
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
get_data_error_result,
|
||||
get_request_json,
|
||||
)
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
@ -132,12 +136,12 @@ async def run():
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
|
||||
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -147,7 +151,7 @@ async def run():
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
|
||||
ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
@ -540,6 +544,7 @@ def sessions(canvas_id):
|
||||
@login_required
|
||||
def prompts():
|
||||
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
|
||||
return get_json_result(data={
|
||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||
"plan_generation": NEXT_STEP,
|
||||
|
||||
@ -13,11 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import base64
|
||||
import xxhash
|
||||
from quart import request
|
||||
|
||||
@ -27,8 +27,14 @@ from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_request_json
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
get_request_json,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -38,7 +44,6 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
@ -61,7 +66,7 @@ async def list_chunk():
|
||||
}
|
||||
if "available_int" in req:
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
||||
sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@ -126,10 +131,15 @@ def get():
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
async def set():
|
||||
req = await get_request_json()
|
||||
content_with_weight = req["content_with_weight"]
|
||||
if not isinstance(content_with_weight, (str, bytes)):
|
||||
raise TypeError("expected string or bytes-like object")
|
||||
if isinstance(content_with_weight, bytes):
|
||||
content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
||||
"content_with_weight": content_with_weight}
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if "important_kwd" in req:
|
||||
if not isinstance(req["important_kwd"], list):
|
||||
@ -171,20 +181,21 @@ async def set():
|
||||
_d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
|
||||
v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
_d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
# update image
|
||||
image_base64 = req.get("image_base64", None)
|
||||
if image_base64:
|
||||
bkt, name = req.get("img_id", "-").split("-")
|
||||
img_id = req.get("img_id", "")
|
||||
if image_base64 and img_id and "-" in img_id:
|
||||
bkt, name = img_id.split("-", 1)
|
||||
image_binary = base64.b64decode(image_base64)
|
||||
settings.STORAGE_IMPL.put(bkt, name, image_binary)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_set_sync)
|
||||
return await thread_pool_exec(_set_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -207,7 +218,7 @@ async def switch():
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_switch_sync)
|
||||
return await thread_pool_exec(_switch_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -222,19 +233,34 @@ async def rm():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
|
||||
try:
|
||||
deleted_count = settings.docStoreConn.delete(condition,
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id)
|
||||
except Exception:
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
if isinstance(deleted_chunk_ids, list):
|
||||
unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
|
||||
has_ids = len(unique_chunk_ids) > 0
|
||||
else:
|
||||
unique_chunk_ids = [deleted_chunk_ids]
|
||||
has_ids = deleted_chunk_ids not in (None, "")
|
||||
if has_ids and deleted_count == 0:
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
|
||||
deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id)
|
||||
chunk_number = deleted_count
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
return await thread_pool_exec(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -244,6 +270,7 @@ async def rm():
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
req_id = request.headers.get("X-Request-ID")
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -260,14 +287,23 @@ async def create():
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
|
||||
try:
|
||||
def _log_response(resp, code, message):
|
||||
logging.info(
|
||||
"chunk_create response req_id=%s status=%s code=%s message=%s",
|
||||
req_id,
|
||||
getattr(resp, "status_code", None),
|
||||
code,
|
||||
message,
|
||||
)
|
||||
|
||||
def _create_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
resp = get_data_error_result(message="Document not found!")
|
||||
_log_response(resp, RetCode.DATA_ERROR, "Document not found!")
|
||||
return resp
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
@ -275,11 +311,15 @@ async def create():
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
resp = get_data_error_result(message="Tenant not found!")
|
||||
_log_response(resp, RetCode.DATA_ERROR, "Tenant not found!")
|
||||
return resp
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
resp = get_data_error_result(message="Knowledgebase not found!")
|
||||
_log_response(resp, RetCode.DATA_ERROR, "Knowledgebase not found!")
|
||||
return resp
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
|
||||
@ -293,10 +333,13 @@ async def create():
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
resp = get_json_result(data={"chunk_id": chunck_id})
|
||||
_log_response(resp, RetCode.SUCCESS, "success")
|
||||
return resp
|
||||
|
||||
return await asyncio.to_thread(_create_sync)
|
||||
return await thread_pool_exec(_create_sync)
|
||||
except Exception as e:
|
||||
logging.info("chunk_create exception req_id=%s error=%r", req_id, e)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -372,14 +415,21 @@ async def retrieval_test():
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
top,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl,
|
||||
highlight=req.get("highlight", False),
|
||||
rank_feature=labels
|
||||
)
|
||||
ranks = await settings.retriever.retrieval(
|
||||
_question,
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
page,
|
||||
size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
doc_ids=local_doc_ids,
|
||||
top=top,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=labels
|
||||
)
|
||||
|
||||
if use_kg:
|
||||
ck = await settings.kg_retriever.retrieval(_question,
|
||||
tenant_ids,
|
||||
@ -407,7 +457,7 @@ async def retrieval_test():
|
||||
|
||||
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def knowledge_graph():
|
||||
async def knowledge_graph():
|
||||
doc_id = request.args["doc_id"]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
@ -415,7 +465,7 @@ def knowledge_graph():
|
||||
"doc_ids": [doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
||||
sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
|
||||
@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.apps import login_required, current_user
|
||||
import logging
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@ -42,13 +43,19 @@ async def set_dialog():
|
||||
if len(name.encode("utf-8")) > 255:
|
||||
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
|
||||
|
||||
if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
|
||||
name = name.strip()
|
||||
name = duplicate_name(
|
||||
DialogService.query,
|
||||
name=name,
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value)
|
||||
name = name.strip()
|
||||
if is_create:
|
||||
# only for chat creating
|
||||
existing_names = {
|
||||
d.name.casefold()
|
||||
for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
|
||||
if d.name
|
||||
}
|
||||
if name.casefold() in existing_names:
|
||||
def _name_exists(name: str, **_kwargs) -> bool:
|
||||
return name.casefold() in existing_names
|
||||
|
||||
name = duplicate_name(_name_exists, name=name)
|
||||
|
||||
description = req.get("description", "A helpful dialog")
|
||||
icon = req.get("icon", "")
|
||||
@ -63,16 +70,30 @@ async def set_dialog():
|
||||
meta_data_filter = req.get("meta_data_filter", {})
|
||||
prompt_config = req["prompt_config"]
|
||||
|
||||
# Set default parameters for datasets with knowledge retrieval
|
||||
# All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
|
||||
kb_ids = req.get("kb_ids", [])
|
||||
parameters = prompt_config.get("parameters")
|
||||
logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
|
||||
# Check if parameters is missing, None, or empty list
|
||||
if kb_ids and not parameters:
|
||||
# Check if system prompt uses {knowledge} placeholder
|
||||
if "{knowledge}" in prompt_config.get("system", ""):
|
||||
# Set default parameters for any dataset with knowledge placeholder
|
||||
prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
|
||||
logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")
|
||||
|
||||
if not is_create:
|
||||
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
|
||||
# only for chat updating
|
||||
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
|
||||
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["optional"]:
|
||||
continue
|
||||
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_data_error_result(
|
||||
message="Parameter '{}' is not used".format(p["key"]))
|
||||
for p in prompt_config.get("parameters", []):
|
||||
if p["optional"]:
|
||||
continue
|
||||
if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
|
||||
return get_data_error_result(
|
||||
message="Parameter '{}' is not used".format(p["key"]))
|
||||
|
||||
try:
|
||||
e, tenant = TenantService.get_by_id(current_user.id)
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
@ -27,18 +26,19 @@ from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from common.metadata_utils import meta_filter, convert_conditions, turn2jsonschema
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request, get_request_json,
|
||||
validate_request,
|
||||
get_request_json,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -62,10 +62,21 @@ async def upload():
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = files.getlist("file")
|
||||
def _close_file_objs(objs):
|
||||
for obj in objs:
|
||||
try:
|
||||
obj.close()
|
||||
except Exception:
|
||||
try:
|
||||
obj.stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
_close_file_objs(file_objs)
|
||||
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
_close_file_objs(file_objs)
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
@ -74,8 +85,9 @@ async def upload():
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
|
||||
err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
|
||||
if err:
|
||||
files = [f[0] for f in files] if files else []
|
||||
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
|
||||
if not files:
|
||||
@ -214,6 +226,7 @@ async def list_docs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
@ -333,6 +346,8 @@ async def list_docs():
|
||||
doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
|
||||
if doc_item.get("source_type"):
|
||||
doc_item["source_type"] = doc_item["source_type"].split("/")[0]
|
||||
if doc_item["parser_config"].get("metadata"):
|
||||
doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
except Exception as e:
|
||||
@ -394,6 +409,7 @@ async def doc_infos():
|
||||
async def metadata_summary():
|
||||
req = await get_request_json()
|
||||
kb_id = req.get("kb_id")
|
||||
doc_ids = req.get("doc_ids")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
@ -405,7 +421,7 @@ async def metadata_summary():
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
try:
|
||||
summary = DocumentService.get_metadata_summary(kb_id)
|
||||
summary = DocumentService.get_metadata_summary(kb_id, doc_ids)
|
||||
return get_json_result(data={"summary": summary})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -413,36 +429,16 @@ async def metadata_summary():
|
||||
|
||||
@manager.route("/metadata/update", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids")
|
||||
async def metadata_update():
|
||||
req = await get_request_json()
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
selector = req.get("selector", {}) or {}
|
||||
document_ids = req.get("doc_ids")
|
||||
updates = req.get("updates", []) or []
|
||||
deletes = req.get("deletes", []) or []
|
||||
|
||||
if not isinstance(selector, dict):
|
||||
return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR)
|
||||
if not isinstance(updates, list) or not isinstance(deletes, list):
|
||||
return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
metadata_condition = selector.get("metadata_condition", {}) or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
document_ids = selector.get("document_ids", []) or []
|
||||
if document_ids and not isinstance(document_ids, list):
|
||||
return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
for upd in updates:
|
||||
if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
|
||||
return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -450,24 +446,8 @@ async def metadata_update():
|
||||
if not isinstance(d, dict) or not d.get("key"):
|
||||
return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id])
|
||||
target_doc_ids = set(kb_doc_ids)
|
||||
if document_ids:
|
||||
invalid_ids = set(document_ids) - set(kb_doc_ids)
|
||||
if invalid_ids:
|
||||
return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR)
|
||||
target_doc_ids = set(document_ids)
|
||||
|
||||
if metadata_condition:
|
||||
metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
|
||||
filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
target_doc_ids = target_doc_ids & filtered_ids
|
||||
if metadata_condition.get("conditions") and not target_doc_ids:
|
||||
return get_json_result(data={"updated": 0, "matched_docs": 0})
|
||||
|
||||
target_doc_ids = list(target_doc_ids)
|
||||
updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes)
|
||||
return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
|
||||
updated = DocumentService.batch_update_metadata(None, document_ids, updates, deletes)
|
||||
return get_json_result(data={"updated": updated})
|
||||
|
||||
|
||||
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
|
||||
@ -521,31 +501,61 @@ async def change_status():
|
||||
return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
result = {}
|
||||
has_error = False
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
result[doc_id] = {"error": "No authorization."}
|
||||
has_error = True
|
||||
continue
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
result[doc_id] = {"error": "No authorization."}
|
||||
has_error = True
|
||||
continue
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
result[doc_id] = {"error": "Can't find this dataset!"}
|
||||
has_error = True
|
||||
continue
|
||||
current_status = str(doc.status)
|
||||
if current_status == status:
|
||||
result[doc_id] = {"status": status}
|
||||
continue
|
||||
if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
|
||||
result[doc_id] = {"error": "Database error (Document update)!"}
|
||||
has_error = True
|
||||
continue
|
||||
|
||||
status_int = int(status)
|
||||
if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
|
||||
result[doc_id] = {"error": "Database error (docStore update)!"}
|
||||
if getattr(doc, "chunk_num", 0) > 0:
|
||||
try:
|
||||
ok = settings.docStoreConn.update(
|
||||
{"doc_id": doc_id},
|
||||
{"available_int": status_int},
|
||||
search.index_name(kb.tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
msg = str(exc)
|
||||
if "3022" in msg:
|
||||
result[doc_id] = {"error": "Document store table missing."}
|
||||
else:
|
||||
result[doc_id] = {"error": f"Document store update failed: {msg}"}
|
||||
has_error = True
|
||||
continue
|
||||
if not ok:
|
||||
result[doc_id] = {"error": "Database error (docStore update)!"}
|
||||
has_error = True
|
||||
continue
|
||||
result[doc_id] = {"status": status}
|
||||
except Exception as e:
|
||||
result[doc_id] = {"error": f"Internal server error: {str(e)}"}
|
||||
has_error = True
|
||||
|
||||
if has_error:
|
||||
return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=result)
|
||||
|
||||
|
||||
@ -562,7 +572,7 @@ async def rm():
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
|
||||
errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
@ -575,10 +585,11 @@ async def rm():
|
||||
@validate_request("doc_ids", "run")
|
||||
async def run():
|
||||
req = await get_request_json()
|
||||
uid = current_user.id
|
||||
try:
|
||||
def _run_sync():
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
if not DocumentService.accessible(doc_id, uid):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
kb_table_num_map = {}
|
||||
@ -615,6 +626,7 @@ async def run():
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this dataset!")
|
||||
doc.parser_config["llm_id"] = kb.parser_config.get("llm_id")
|
||||
doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False)
|
||||
doc.parser_config["metadata"] = kb.parser_config.get("metadata", {})
|
||||
DocumentService.update_parser_config(doc.id, doc.parser_config)
|
||||
@ -623,7 +635,7 @@ async def run():
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_run_sync)
|
||||
return await thread_pool_exec(_run_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -633,9 +645,10 @@ async def run():
|
||||
@validate_request("doc_id", "name")
|
||||
async def rename():
|
||||
req = await get_request_json()
|
||||
uid = current_user.id
|
||||
try:
|
||||
def _rename_sync():
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
if not DocumentService.accessible(req["doc_id"], uid):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
@ -674,7 +687,7 @@ async def rename():
|
||||
)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rename_sync)
|
||||
return await thread_pool_exec(_rename_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -689,7 +702,7 @@ async def get(doc_id):
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
|
||||
response = await make_response(data)
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
@ -711,7 +724,7 @@ async def get(doc_id):
|
||||
async def download_attachment(attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
@ -784,7 +797,7 @@ async def get_image(image_id):
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License
|
||||
#
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
@ -25,7 +24,7 @@ from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from common.constants import RetCode, FileSource
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
@ -35,7 +34,6 @@ from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
|
||||
|
||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
# @validate_request("parent_id")
|
||||
@ -65,7 +63,7 @@ async def upload():
|
||||
|
||||
async def _handle_single_file(file_obj):
|
||||
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
|
||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||
|
||||
# split file name path
|
||||
@ -77,35 +75,35 @@ async def upload():
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# get folder
|
||||
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||
file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# create folder
|
||||
if file_len != len_id_list:
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
# file type
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||
while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||
location += "_"
|
||||
blob = await asyncio.to_thread(file_obj.read)
|
||||
filename = await asyncio.to_thread(
|
||||
blob = await thread_pool_exec(file_obj.read)
|
||||
filename = await thread_pool_exec(
|
||||
duplicate_name,
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
parent_id=last_folder.id)
|
||||
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||
await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||
file_data = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": last_folder.id,
|
||||
@ -116,7 +114,7 @@ async def upload():
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
}
|
||||
inserted = await asyncio.to_thread(FileService.insert, file_data)
|
||||
inserted = await thread_pool_exec(FileService.insert, file_data)
|
||||
return inserted.to_json()
|
||||
|
||||
for file_obj in file_objs:
|
||||
@ -249,6 +247,7 @@ def get_all_parent_folders():
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
uid = current_user.id
|
||||
|
||||
try:
|
||||
def _delete_single_file(file):
|
||||
@ -287,21 +286,21 @@ async def rm():
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
if not check_file_team_permission(file, uid):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
_delete_folder_recursive(file, uid)
|
||||
continue
|
||||
|
||||
_delete_single_file(file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
return await thread_pool_exec(_rm_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -357,10 +356,10 @@ async def get(file_id):
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
|
||||
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
@ -460,7 +459,7 @@ async def move():
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_move_sync)
|
||||
return await thread_pool_exec(_move_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -17,8 +17,8 @@ import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from common.metadata_utils import turn2jsonschema
|
||||
from quart import request
|
||||
import numpy as np
|
||||
|
||||
@ -30,8 +30,15 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
|
||||
get_request_json
|
||||
from api.utils.api_utils import (
|
||||
get_error_data_result,
|
||||
server_error_response,
|
||||
get_data_error_result,
|
||||
validate_request,
|
||||
not_allowed_parameters,
|
||||
get_request_json,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
@ -44,7 +51,6 @@ from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
@ -82,6 +88,20 @@ async def update():
|
||||
return get_data_error_result(
|
||||
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
|
||||
req["name"] = req["name"].strip()
|
||||
if settings.DOC_ENGINE_INFINITY:
|
||||
parser_id = req.get("parser_id")
|
||||
if isinstance(parser_id, str) and parser_id.lower() == "tag":
|
||||
return get_json_result(
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
message="The chunking method Tag has not been supported by Infinity yet.",
|
||||
data=False,
|
||||
)
|
||||
if "pagerank" in req and req["pagerank"] > 0:
|
||||
return get_json_result(
|
||||
code=RetCode.DATA_ERROR,
|
||||
message="'pagerank' can only be set when doc_engine is elasticsearch",
|
||||
data=False,
|
||||
)
|
||||
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
return get_json_result(
|
||||
@ -130,7 +150,7 @@ async def update():
|
||||
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.update,
|
||||
{"kb_id": kb.id},
|
||||
{PAGERANK_FLD: req["pagerank"]},
|
||||
@ -139,7 +159,7 @@ async def update():
|
||||
)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.update,
|
||||
{"exists": PAGERANK_FLD},
|
||||
{"remove": PAGERANK_FLD},
|
||||
@ -174,6 +194,7 @@ async def update_metadata_setting():
|
||||
message="Database error (Knowledgebase rename)!")
|
||||
kb = kb.to_dict()
|
||||
kb["parser_config"]["metadata"] = req["metadata"]
|
||||
kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True)
|
||||
KnowledgebaseService.update_by_id(kb["id"], kb)
|
||||
return get_json_result(data=kb)
|
||||
|
||||
@ -198,6 +219,8 @@ def detail():
|
||||
message="Can't find this dataset!")
|
||||
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
|
||||
kb["connectors"] = Connector2KbService.list_connectors(kb_id)
|
||||
if kb["parser_config"].get("metadata"):
|
||||
kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"])
|
||||
|
||||
for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
|
||||
if finish_at := kb.get(key):
|
||||
@ -249,7 +272,8 @@ async def list_kbs():
|
||||
@validate_request("kb_id")
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
uid = current_user.id
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], uid):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
@ -257,7 +281,7 @@ async def rm():
|
||||
)
|
||||
try:
|
||||
kbs = KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=req["kb_id"])
|
||||
created_by=uid, id=req["kb_id"])
|
||||
if not kbs:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of dataset authorized for this operation.',
|
||||
@ -280,17 +304,24 @@ async def rm():
|
||||
File.name == kbs[0].name,
|
||||
]
|
||||
)
|
||||
# Delete the table BEFORE deleting the database record
|
||||
for kb in kbs:
|
||||
try:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
|
||||
logging.info(f"Dropped index for dataset {kb.id}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to drop index for dataset {kb.id}: {e}")
|
||||
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
return await thread_pool_exec(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -372,7 +403,7 @@ async def rename_tags(kb_id):
|
||||
|
||||
@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def knowledge_graph(kb_id):
|
||||
async def knowledge_graph(kb_id):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -388,7 +419,7 @@ def knowledge_graph(kb_id):
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
|
||||
return get_json_result(data=obj)
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
if not len(sres.ids):
|
||||
return get_json_result(data=obj)
|
||||
|
||||
|
||||
@ -195,6 +195,9 @@ async def add_llm():
|
||||
elif factory == "MinerU":
|
||||
api_key = apikey_json(["api_key", "provider_order"])
|
||||
|
||||
elif factory == "PaddleOCR":
|
||||
api_key = apikey_json(["api_key", "provider_order"])
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": factory,
|
||||
@ -230,8 +233,7 @@ async def add_llm():
|
||||
**extra,
|
||||
)
|
||||
try:
|
||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
|
||||
{"temperature": 0.9})
|
||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||
if not tc and m.find("**ERROR**:") >= 0:
|
||||
raise Exception(m)
|
||||
except Exception as e:
|
||||
@ -371,17 +373,18 @@ def my_llms():
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_app():
|
||||
async def list_app():
|
||||
self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
|
||||
weighted = []
|
||||
model_type = request.args.get("model_type")
|
||||
tenant_id = current_user.id
|
||||
try:
|
||||
TenantLLMService.ensure_mineru_from_env(current_user.id)
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
TenantLLMService.ensure_mineru_from_env(tenant_id)
|
||||
objs = TenantLLMService.query(tenant_id=tenant_id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
|
||||
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == 'Builtin' or (m.llm_name + "@" + m.fid) in status)]
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == "Builtin" or (m.llm_name + "@" + m.fid) in status)]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
|
||||
if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):
|
||||
|
||||
@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
@ -23,12 +21,11 @@ from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_mcp() -> Response:
|
||||
@ -108,7 +105,7 @@ async def create() -> Response:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
@ -160,7 +157,7 @@ async def update() -> Response:
|
||||
req["id"] = mcp_id
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
@ -244,7 +241,7 @@ async def import_multiple() -> Response:
|
||||
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
|
||||
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
|
||||
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
results.append({"server": base_name, "success": False, "message": err_message})
|
||||
continue
|
||||
@ -324,7 +321,7 @@ async def list_tools() -> Response:
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
|
||||
try:
|
||||
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
|
||||
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"MCP list tools error: {e}")
|
||||
|
||||
@ -341,7 +338,7 @@ async def list_tools() -> Response:
|
||||
return server_error_response(e)
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
|
||||
|
||||
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||
@ -368,10 +365,10 @@ async def test_tool() -> Response:
|
||||
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
|
||||
result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
|
||||
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -425,12 +422,12 @@ async def test_mcp() -> Response:
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
|
||||
try:
|
||||
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
|
||||
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Test MCP error: {e}")
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
|
||||
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
|
||||
|
||||
for tool in tools:
|
||||
tool_dict = tool.model_dump()
|
||||
|
||||
@ -51,7 +51,7 @@ def list_agents(tenant_id):
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
order_by = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
if str(request.args.get("desc","false")).lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
@ -162,6 +162,7 @@ async def webhook(agent_id: str):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||
|
||||
# 4. Check webhook configuration in DSL
|
||||
webhook_cfg = {}
|
||||
components = dsl.get("components", {})
|
||||
for k, _ in components.items():
|
||||
cpn_obj = components[k]["obj"]
|
||||
|
||||
@ -51,7 +51,9 @@ async def create(tenant_id):
|
||||
req["llm_id"] = llm.pop("model_name")
|
||||
if req.get("llm_id") is not None:
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
|
||||
model_type = llm.get("model_type")
|
||||
model_type = model_type if model_type in ["chat", "image2text"] else "chat"
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
|
||||
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
|
||||
req["llm_setting"] = req.pop("llm")
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
@ -174,7 +176,7 @@ async def update(tenant_id, chat_id):
|
||||
req["llm_id"] = llm.pop("model_name")
|
||||
if req.get("llm_id") is not None:
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
|
||||
model_type = llm.pop("model_type")
|
||||
model_type = llm.get("model_type")
|
||||
model_type = model_type if model_type in ["chat", "image2text"] else "chat"
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
|
||||
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
|
||||
|
||||
@ -233,6 +233,15 @@ async def delete(tenant_id):
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
|
||||
|
||||
# Drop index for this dataset
|
||||
try:
|
||||
from rag.nlp import search
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
settings.docStoreConn.delete_idx(idxnm, kb_id)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")
|
||||
|
||||
if not KnowledgebaseService.delete_by_id(kb_id):
|
||||
errors.append(f"Delete dataset error for {kb_id}")
|
||||
continue
|
||||
@ -481,7 +490,7 @@ def list_datasets(tenant_id):
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def knowledge_graph(tenant_id, dataset_id):
|
||||
async def knowledge_graph(tenant_id, dataset_id):
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
@ -497,7 +506,7 @@ def knowledge_graph(tenant_id, dataset_id):
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
|
||||
return get_result(data=obj)
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
if not len(sres.ids):
|
||||
return get_result(data=obj)
|
||||
|
||||
|
||||
@ -135,7 +135,7 @@ async def retrieval(tenant_id):
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
if not doc_ids and metadata_condition:
|
||||
doc_ids = ["-999"]
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = await settings.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
kb.tenant_id,
|
||||
|
||||
@ -606,12 +606,12 @@ def list_docs(dataset_id, tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/metadata/summary", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def metadata_summary(dataset_id, tenant_id):
|
||||
async def metadata_summary(dataset_id, tenant_id):
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
|
||||
req = await get_request_json()
|
||||
try:
|
||||
summary = DocumentService.get_metadata_summary(dataset_id)
|
||||
summary = DocumentService.get_metadata_summary(dataset_id, req.get("doc_ids"))
|
||||
return get_result(data={"summary": summary})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -647,10 +647,10 @@ async def metadata_batch_update(dataset_id, tenant_id):
|
||||
for d in deletes:
|
||||
if not isinstance(d, dict) or not d.get("key"):
|
||||
return get_error_data_result(message="Each delete requires key.")
|
||||
|
||||
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
|
||||
target_doc_ids = set(kb_doc_ids)
|
||||
|
||||
if document_ids:
|
||||
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
|
||||
target_doc_ids = set(kb_doc_ids)
|
||||
invalid_ids = set(document_ids) - set(kb_doc_ids)
|
||||
if invalid_ids:
|
||||
return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
|
||||
@ -935,7 +935,7 @@ async def stop_parsing(tenant_id, dataset_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_chunks(tenant_id, dataset_id, document_id):
|
||||
async def list_chunks(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
List chunks of a document.
|
||||
---
|
||||
@ -1081,7 +1081,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
_ = Chunk(**final_chunk)
|
||||
|
||||
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
res["total"] = sres.total
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@ -1519,11 +1519,12 @@ async def retrieval_test(tenant_id):
|
||||
toc_enhance = req.get("toc_enhance", False)
|
||||
langs = req.get("cross_languages", [])
|
||||
if not isinstance(doc_ids, list):
|
||||
return get_error_data_result("`documents` should be a list")
|
||||
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
|
||||
for doc_id in doc_ids:
|
||||
if doc_id not in doc_ids_list:
|
||||
return get_error_data_result(f"The datasets don't own the document {doc_id}")
|
||||
return get_error_data_result("`documents` should be a list")
|
||||
if doc_ids:
|
||||
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
|
||||
for doc_id in doc_ids:
|
||||
if doc_id not in doc_ids_list:
|
||||
return get_error_data_result(f"The datasets don't own the document {doc_id}")
|
||||
if not doc_ids:
|
||||
metadata_condition = req.get("metadata_condition", {}) or {}
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
@ -1558,7 +1559,7 @@ async def retrieval_test(tenant_id):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += await keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = await settings.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
@ -1575,7 +1576,7 @@ async def retrieval_test(tenant_id):
|
||||
)
|
||||
if toc_enhance:
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
|
||||
cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
|
||||
if cks:
|
||||
ranks["chunks"] = cks
|
||||
if use_kg:
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
import re
|
||||
from quart import request, make_response
|
||||
@ -24,7 +23,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
@ -33,7 +32,6 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def upload(tenant_id):
|
||||
@ -640,7 +638,7 @@ async def get(tenant_id, file_id):
|
||||
async def download_attachment(tenant_id, attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
@ -21,6 +23,7 @@ from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
@ -30,26 +33,60 @@ from memory.utils.prompt_util import PromptAssembler
|
||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||
|
||||
|
||||
@manager.route("", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/memories", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "memory_type", "embd_id", "llm_id")
|
||||
async def create_memory():
|
||||
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
|
||||
t_start = time.perf_counter() if timing_enabled else None
|
||||
req = await get_request_json()
|
||||
t_parsed = time.perf_counter() if timing_enabled else None
|
||||
# check name length
|
||||
name = req["name"]
|
||||
memory_name = name.strip()
|
||||
if len(memory_name) == 0:
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result("Memory name cannot be empty or whitespace.")
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
|
||||
# check memory_type valid
|
||||
if not isinstance(req["memory_type"], list):
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result("Memory type must be a list.")
|
||||
memory_type = set(req["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
|
||||
memory_type = list(memory_type)
|
||||
|
||||
try:
|
||||
t_before_db = time.perf_counter() if timing_enabled else None
|
||||
res, memory = MemoryService.create_memory(
|
||||
tenant_id=current_user.id,
|
||||
name=memory_name,
|
||||
@ -57,6 +94,15 @@ async def create_memory():
|
||||
embd_id=req["embd_id"],
|
||||
llm_id=req["llm_id"]
|
||||
)
|
||||
if timing_enabled:
|
||||
logging.info(
|
||||
"api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
|
||||
(t_parsed - t_start) * 1000,
|
||||
(t_before_db - t_parsed) * 1000,
|
||||
(time.perf_counter() - t_before_db) * 1000,
|
||||
(time.perf_counter() - t_start) * 1000,
|
||||
request.path,
|
||||
)
|
||||
|
||||
if res:
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
@ -67,7 +113,7 @@ async def create_memory():
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@manager.route("/memories/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
async def update_memory(memory_id):
|
||||
req = await get_request_json()
|
||||
@ -151,7 +197,7 @@ async def update_memory(memory_id):
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/memories/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
@ -167,7 +213,7 @@ async def delete_memory(memory_id):
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/memories", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_memory():
|
||||
args = request.args
|
||||
@ -179,13 +225,18 @@ async def list_memory():
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
# make filter dict
|
||||
filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
|
||||
filter_dict: dict = {"storage_type": storage_type}
|
||||
if not tenant_ids:
|
||||
# restrict to current user's tenants
|
||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||
else:
|
||||
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
|
||||
tenant_ids = tenant_ids[0].split(',')
|
||||
filter_dict["tenant_id"] = tenant_ids
|
||||
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
|
||||
memory_types = memory_types[0].split(',')
|
||||
filter_dict["memory_type"] = memory_types
|
||||
|
||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
||||
@ -196,7 +247,7 @@ async def list_memory():
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/memories/<memory_id>/config", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_config(memory_id):
|
||||
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
||||
@ -205,11 +256,13 @@ async def get_memory_config(memory_id):
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/memories/<memory_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_detail(memory_id):
|
||||
args = request.args
|
||||
agent_ids = args.getlist("agent_id")
|
||||
if len(agent_ids) == 1 and ',' in agent_ids[0]:
|
||||
agent_ids = agent_ids[0].split(',')
|
||||
keywords = args.get("keywords", "")
|
||||
keywords = keywords.strip()
|
||||
page = int(args.get("page", 1))
|
||||
@ -220,9 +273,19 @@ async def get_memory_detail(memory_id):
|
||||
messages = MessageService.list_message(
|
||||
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
|
||||
agent_name_mapping = {}
|
||||
extract_task_mapping = {}
|
||||
if messages["message_list"]:
|
||||
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
|
||||
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
|
||||
task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
|
||||
if task_list:
|
||||
task_list.sort(key=lambda t: t["create_time"]) # asc, use newer when exist more than one task
|
||||
for task in task_list:
|
||||
# the 'digest' field carries the source_id when a task is created, so use 'digest' as key
|
||||
extract_task_mapping.update({int(task["digest"]): task})
|
||||
for message in messages["message_list"]:
|
||||
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
|
||||
message["task"] = extract_task_mapping.get(message["message_id"], {})
|
||||
for extract_msg in message["extract"]:
|
||||
extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
|
||||
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
|
||||
@ -24,44 +24,31 @@ from api.utils.api_utils import validate_request, get_request_json, get_error_ar
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
@manager.route("", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/messages", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
|
||||
async def add_message():
|
||||
|
||||
req = await get_request_json()
|
||||
memory_ids = req["memory_id"]
|
||||
agent_id = req["agent_id"]
|
||||
session_id = req["session_id"]
|
||||
user_id = req["user_id"] if req.get("user_id") else ""
|
||||
user_input = req["user_input"]
|
||||
agent_response = req["agent_response"]
|
||||
|
||||
res = []
|
||||
for memory_id in memory_ids:
|
||||
success, msg = await memory_message_service.save_to_memory(
|
||||
memory_id,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_input": user_input,
|
||||
"agent_response": agent_response
|
||||
}
|
||||
)
|
||||
res.append({
|
||||
"memory_id": memory_id,
|
||||
"success": success,
|
||||
"message": msg
|
||||
})
|
||||
message_dict = {
|
||||
"user_id": req.get("user_id"),
|
||||
"agent_id": req["agent_id"],
|
||||
"session_id": req["session_id"],
|
||||
"user_input": req["user_input"],
|
||||
"agent_response": req["agent_response"],
|
||||
}
|
||||
|
||||
if all([r["success"] for r in res]):
|
||||
return get_json_result(message="Successfully added to memories.")
|
||||
res, msg = await memory_message_service.queue_save_to_memory_task(memory_ids, message_dict)
|
||||
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
|
||||
if res:
|
||||
return get_json_result(message=msg)
|
||||
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add. Detail:" + msg)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/messages/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def forget_message(memory_id: str, message_id: int):
|
||||
|
||||
@ -80,7 +67,7 @@ async def forget_message(memory_id: str, message_id: int):
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
|
||||
|
||||
|
||||
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
|
||||
@manager.route("/messages/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("status")
|
||||
async def update_message(memory_id: str, message_id: int):
|
||||
@ -100,16 +87,17 @@ async def update_message(memory_id: str, message_id: int):
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
|
||||
|
||||
|
||||
@manager.route("/search", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/messages/search", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def search_message():
|
||||
args = request.args
|
||||
print(args, flush=True)
|
||||
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
|
||||
if empty_fields:
|
||||
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
|
||||
|
||||
memory_ids = args.getlist("memory_id")
|
||||
if len(memory_ids) == 1 and ',' in memory_ids[0]:
|
||||
memory_ids = memory_ids[0].split(',')
|
||||
query = args.get("query")
|
||||
similarity_threshold = float(args.get("similarity_threshold", 0.2))
|
||||
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
|
||||
@ -132,11 +120,13 @@ async def search_message():
|
||||
return get_json_result(message=True, data=res)
|
||||
|
||||
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/messages", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_messages():
|
||||
args = request.args
|
||||
memory_ids = args.getlist("memory_id")
|
||||
if len(memory_ids) == 1 and ',' in memory_ids[0]:
|
||||
memory_ids = memory_ids[0].split(',')
|
||||
agent_id = args.get("agent_id", "")
|
||||
session_id = args.get("session_id", "")
|
||||
limit = int(args.get("limit", 10))
|
||||
@ -154,7 +144,7 @@ async def get_messages():
|
||||
return get_json_result(message=True, data=res)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/messages/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_message_content(memory_id:str, message_id: int):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
@ -19,6 +19,10 @@ import re
|
||||
import time
|
||||
|
||||
import tiktoken
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
|
||||
from quart import Response, jsonify, request
|
||||
|
||||
from agent.canvas import Canvas
|
||||
@ -35,7 +39,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.user_service import TenantService,UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
@ -304,9 +308,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
# The choices field on the last chunk will always be an empty array [].
|
||||
async def streamed_response_generator(chat_id, dia, msg):
|
||||
token_used = 0
|
||||
answer_cache = ""
|
||||
reasoning_cache = ""
|
||||
last_ans = {}
|
||||
full_content = ""
|
||||
full_reasoning = ""
|
||||
final_answer = None
|
||||
final_reference = None
|
||||
in_think = False
|
||||
response = {
|
||||
"id": f"chatcmpl-{chat_id}",
|
||||
"choices": [
|
||||
@ -336,47 +343,30 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, True, **chat_kwargs):
|
||||
last_ans = ans
|
||||
answer = ans["answer"]
|
||||
|
||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||
if reasoning_match:
|
||||
reasoning_part = reasoning_match.group(1)
|
||||
content_part = answer[reasoning_match.end() :]
|
||||
else:
|
||||
reasoning_part = ""
|
||||
content_part = answer
|
||||
|
||||
reasoning_incremental = ""
|
||||
if reasoning_part:
|
||||
if reasoning_part.startswith(reasoning_cache):
|
||||
reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
|
||||
else:
|
||||
reasoning_incremental = reasoning_part
|
||||
reasoning_cache = reasoning_part
|
||||
|
||||
content_incremental = ""
|
||||
if content_part:
|
||||
if content_part.startswith(answer_cache):
|
||||
content_incremental = content_part.replace(answer_cache, "", 1)
|
||||
else:
|
||||
content_incremental = content_part
|
||||
answer_cache = content_part
|
||||
|
||||
token_used += len(reasoning_incremental) + len(content_incremental)
|
||||
|
||||
if not any([reasoning_incremental, content_incremental]):
|
||||
if ans.get("final"):
|
||||
if ans.get("answer"):
|
||||
full_content = ans["answer"]
|
||||
final_answer = ans.get("answer") or full_content
|
||||
final_reference = ans.get("reference", {})
|
||||
continue
|
||||
|
||||
if reasoning_incremental:
|
||||
response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
|
||||
else:
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
|
||||
if content_incremental:
|
||||
response["choices"][0]["delta"]["content"] = content_incremental
|
||||
else:
|
||||
if ans.get("start_to_think"):
|
||||
in_think = True
|
||||
continue
|
||||
if ans.get("end_to_think"):
|
||||
in_think = False
|
||||
continue
|
||||
delta = ans.get("answer") or ""
|
||||
if not delta:
|
||||
continue
|
||||
token_used += len(delta)
|
||||
if in_think:
|
||||
full_reasoning += delta
|
||||
response["choices"][0]["delta"]["reasoning_content"] = delta
|
||||
response["choices"][0]["delta"]["content"] = None
|
||||
|
||||
else:
|
||||
full_content += delta
|
||||
response["choices"][0]["delta"]["content"] = delta
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
|
||||
@ -388,8 +378,9 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
if need_reference:
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||
reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(reference_payload)
|
||||
response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
@ -438,7 +429,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
],
|
||||
}
|
||||
if need_reference:
|
||||
response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", []))
|
||||
response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", {}))
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
@ -1058,11 +1049,13 @@ async def retrieval_test_embedded():
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
rerank_id = req.get("rerank_id", "")
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
|
||||
async def _retrieval():
|
||||
nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
_question = question
|
||||
@ -1074,6 +1067,15 @@ async def retrieval_test_embedded():
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
# Apply search_config settings if not explicitly provided in request
|
||||
if not req.get("similarity_threshold"):
|
||||
similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
|
||||
if not req.get("vector_similarity_weight"):
|
||||
vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
|
||||
if not req.get("top_k"):
|
||||
top = int(search_config.get("top_k", top))
|
||||
if not req.get("rerank_id"):
|
||||
rerank_id = search_config.get("rerank_id", "")
|
||||
else:
|
||||
meta_data_filter = req.get("meta_data_filter") or {}
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
@ -1103,15 +1105,15 @@ async def retrieval_test_embedded():
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||
if rerank_id:
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=rerank_id)
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = await settings.retriever.retrieval(
|
||||
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
@ -1233,3 +1235,93 @@ async def mindmap():
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
||||
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def sequence2txt(tenant_id):
|
||||
req = await request.form
|
||||
stream_mode = req.get("stream", "false").lower() == "true"
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_error_data_result(message="Missing 'file' in multipart form-data")
|
||||
|
||||
uploaded = files["file"]
|
||||
|
||||
ALLOWED_EXTS = {
|
||||
".wav", ".mp3", ".m4a", ".aac",
|
||||
".flac", ".ogg", ".webm",
|
||||
".opus", ".wma"
|
||||
}
|
||||
|
||||
filename = uploaded.filename or ""
|
||||
suffix = os.path.splitext(filename)[-1].lower()
|
||||
if suffix not in ALLOWED_EXTS:
|
||||
return get_error_data_result(message=
|
||||
f"Unsupported audio format: {suffix}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
|
||||
)
|
||||
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
|
||||
os.close(fd)
|
||||
await uploaded.save(temp_audio_path)
|
||||
|
||||
tenants = TenantService.get_info_by(tenant_id)
|
||||
if not tenants:
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
|
||||
asr_id = tenants[0]["asr_id"]
|
||||
if not asr_id:
|
||||
return get_error_data_result(message="No default ASR model is set")
|
||||
|
||||
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
|
||||
if not stream_mode:
|
||||
text = asr_mdl.transcription(temp_audio_path)
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
return get_json_result(data={"text": text})
|
||||
async def event_stream():
|
||||
try:
|
||||
for evt in asr_mdl.stream_transcription(temp_audio_path):
|
||||
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
|
||||
return Response(event_stream(), content_type="text/event-stream")
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def tts(tenant_id):
|
||||
req = await get_request_json()
|
||||
text = req["text"]
|
||||
|
||||
tenants = TenantService.get_info_by(tenant_id)
|
||||
if not tenants:
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
|
||||
tts_id = tenants[0]["tts_id"]
|
||||
if not tts_id:
|
||||
return get_error_data_result(message="No default TTS model is set")
|
||||
|
||||
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
||||
|
||||
def stream_audio():
|
||||
try:
|
||||
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
|
||||
for chunk in tts_mdl.tts(txt):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
|
||||
return resp
|
||||
@ -178,7 +178,7 @@ def healthz():
|
||||
|
||||
|
||||
@manager.route("/ping", methods=["GET"]) # noqa: F821
|
||||
def ping():
|
||||
async def ping():
|
||||
return "pong", 200
|
||||
|
||||
|
||||
|
||||
@ -281,7 +281,11 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to reconnect: {e}")
|
||||
time.sleep(0.1)
|
||||
self.connect()
|
||||
try:
|
||||
self.connect()
|
||||
except Exception as e2:
|
||||
logging.error(f"Failed to reconnect on second attempt: {e2}")
|
||||
raise
|
||||
|
||||
def begin(self):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
@ -352,7 +356,11 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to reconnect to PostgreSQL: {e}")
|
||||
time.sleep(0.1)
|
||||
self.connect()
|
||||
try:
|
||||
self.connect()
|
||||
except Exception as e2:
|
||||
logging.error(f"Failed to reconnect to PostgreSQL on second attempt: {e2}")
|
||||
raise
|
||||
|
||||
def begin(self):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
@ -1197,224 +1205,93 @@ class Memory(DataBaseModel):
|
||||
class Meta:
|
||||
db_table = "memory"
|
||||
|
||||
class SystemSettings(DataBaseModel):
|
||||
name = CharField(max_length=128, primary_key=True)
|
||||
source = CharField(max_length=32, null=False, index=False)
|
||||
data_type = CharField(max_length=32, null=False, index=False)
|
||||
value = CharField(max_length=1024, null=False, index=False)
|
||||
class Meta:
|
||||
db_table = "system_settings"
|
||||
|
||||
def alter_db_add_column(migrator, table_name, column_name, column_type):
|
||||
try:
|
||||
migrate(migrator.add_column(table_name, column_name, column_type))
|
||||
except OperationalError as ex:
|
||||
error_codes = [1060]
|
||||
error_messages = ['Duplicate column name']
|
||||
|
||||
should_skip_error = (
|
||||
(hasattr(ex, 'args') and ex.args and ex.args[0] in error_codes) or
|
||||
(str(ex) in error_messages)
|
||||
)
|
||||
|
||||
if not should_skip_error:
|
||||
logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, operation error: {ex}")
|
||||
|
||||
except Exception as ex:
|
||||
logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, error: {ex}")
|
||||
pass
|
||||
|
||||
def alter_db_column_type(migrator, table_name, column_name, new_column_type):
|
||||
try:
|
||||
migrate(migrator.alter_column_type(table_name, column_name, new_column_type))
|
||||
except Exception as ex:
|
||||
logging.critical(f"Failed to alter {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name} type, error: {ex}")
|
||||
pass
|
||||
|
||||
def alter_db_rename_column(migrator, table_name, old_column_name, new_column_name):
|
||||
try:
|
||||
migrate(migrator.rename_column(table_name, old_column_name, new_column_name))
|
||||
except Exception:
|
||||
# rename fail will lead to a weired error.
|
||||
# logging.critical(f"Failed to rename {settings.DATABASE_TYPE.upper()}.{table_name} column {old_column_name} to {new_column_name}, error: {ex}")
|
||||
pass
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
try:
|
||||
migrate(migrator.add_column("file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("dialog", "top_k", IntegerField(default=1024)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.alter_column_type("tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("task", "retry_count", IntegerField(default=0)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.alter_column_type("api_token", "dialog_id", CharField(max_length=32, null=True, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={})))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
migrate(migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "meta_fields", JSONField(null=True, default={})))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("task", "task_type", CharField(max_length=32, null=False, default="")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.rename_column("task", "process_duation", "process_duration"))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.rename_column("document", "process_duation", "process_duration"))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# RAG Evaluation tables
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
alter_db_add_column(migrator, "file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True))
|
||||
alter_db_add_column(migrator, "tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
|
||||
alter_db_add_column(migrator, "dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
|
||||
alter_db_column_type(migrator, "dialog", "top_k", IntegerField(default=1024))
|
||||
alter_db_add_column(migrator, "tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True))
|
||||
alter_db_add_column(migrator, "api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
||||
alter_db_add_column(migrator, "tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
|
||||
alter_db_add_column(migrator, "api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
||||
alter_db_add_column(migrator, "task", "retry_count", IntegerField(default=0))
|
||||
alter_db_column_type(migrator, "api_token", "dialog_id", CharField(max_length=32, null=True, index=True))
|
||||
alter_db_add_column(migrator, "tenant_llm", "max_tokens", IntegerField(default=8192, index=True))
|
||||
alter_db_add_column(migrator, "api_4_conversation", "dsl", JSONField(null=True, default={}))
|
||||
alter_db_add_column(migrator, "knowledgebase", "pagerank", IntegerField(default=0, index=False))
|
||||
alter_db_add_column(migrator, "api_token", "beta", CharField(max_length=255, null=True, index=True))
|
||||
alter_db_add_column(migrator, "task", "digest", TextField(null=True, help_text="task digest", default=""))
|
||||
alter_db_add_column(migrator, "task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
|
||||
alter_db_add_column(migrator, "conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True))
|
||||
alter_db_add_column(migrator, "document", "meta_fields", JSONField(null=True, default={}))
|
||||
alter_db_add_column(migrator, "task", "task_type", CharField(max_length=32, null=False, default=""))
|
||||
alter_db_add_column(migrator, "task", "priority", IntegerField(default=0))
|
||||
alter_db_add_column(migrator, "user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True))
|
||||
alter_db_add_column(migrator, "llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))
|
||||
alter_db_add_column(migrator, "mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict))
|
||||
alter_db_rename_column(migrator, "task", "process_duation", "process_duration")
|
||||
alter_db_rename_column(migrator, "document", "process_duation", "process_duration")
|
||||
alter_db_add_column(migrator, "document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True))
|
||||
alter_db_add_column(migrator, "api_4_conversation", "errors", TextField(null=True, help_text="errors"))
|
||||
alter_db_add_column(migrator, "dialog", "meta_data_filter", JSONField(null=True, default={}))
|
||||
alter_db_column_type(migrator, "canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title"))
|
||||
alter_db_column_type(migrator, "canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description"))
|
||||
alter_db_add_column(migrator, "user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
|
||||
alter_db_add_column(migrator, "canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
|
||||
alter_db_add_column(migrator, "document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "raptor_task_finish_at", CharField(null=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True))
|
||||
alter_db_add_column(migrator, "knowledgebase", "mindmap_task_finish_at", CharField(null=True))
|
||||
alter_db_column_type(migrator, "tenant_llm", "api_key", TextField(null=True, help_text="API KEY"))
|
||||
alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True))
|
||||
alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False))
|
||||
alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False))
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.db.services.system_settings_service import SystemSettingsService
|
||||
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
|
||||
from common.constants import LLMType
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -158,13 +159,15 @@ def add_graph_templates():
|
||||
CanvasTemplateService.save(**cnvs)
|
||||
except Exception:
|
||||
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
|
||||
except Exception:
|
||||
logging.exception("Add agent templates error: ")
|
||||
except Exception as e:
|
||||
logging.exception(f"Add agent templates error: {e}")
|
||||
|
||||
|
||||
def init_web_data():
|
||||
start_time = time.time()
|
||||
|
||||
init_table()
|
||||
|
||||
init_llm_factory()
|
||||
# if not UserService.get_all().count():
|
||||
# init_superuser()
|
||||
@ -174,6 +177,31 @@ def init_web_data():
|
||||
init_memory_size_cache()
|
||||
logging.info("init web data success:{}".format(time.time() - start_time))
|
||||
|
||||
def init_table():
|
||||
# init system_settings
|
||||
with open(os.path.join(get_project_base_directory(), "conf", "system_settings.json"), "r") as f:
|
||||
records_from_file = json.load(f)["system_settings"]
|
||||
|
||||
record_index = {}
|
||||
records_from_db = SystemSettingsService.get_all()
|
||||
for index, record in enumerate(records_from_db):
|
||||
record_index[record.name] = index
|
||||
|
||||
to_save = []
|
||||
for record in records_from_file:
|
||||
setting_name = record["name"]
|
||||
if setting_name not in record_index:
|
||||
to_save.append(record)
|
||||
|
||||
len_to_save = len(to_save)
|
||||
if len_to_save > 0:
|
||||
# not initialized
|
||||
try:
|
||||
SystemSettingsService.insert_many(to_save, len_to_save)
|
||||
except Exception as e:
|
||||
logging.exception("System settings init error: {}".format(e))
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
init_web_db()
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from api.db.services.task_service import TaskService
|
||||
from common import settings
|
||||
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
|
||||
from common.constants import MemoryType, LLMType
|
||||
@ -24,6 +23,7 @@ from common.doc_store.doc_store_base import FusionExpr
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.db_models import Task
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
@ -90,13 +90,19 @@ async def save_to_memory(memory_id: str, message_dict: dict):
|
||||
return await embed_and_save(memory, message_list)
|
||||
|
||||
|
||||
async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int):
|
||||
async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return False, f"Memory '{memory_id}' not found."
|
||||
msg = f"Memory '{memory_id}' not found."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return False, msg
|
||||
|
||||
if memory.memory_type == MemoryType.RAW.value:
|
||||
return True, f"Memory '{memory_id}' don't need to extract."
|
||||
msg = f"Memory '{memory_id}' don't need to extract."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return True, msg
|
||||
|
||||
tenant_id = memory.tenant_id
|
||||
extracted_content = await extract_by_llm(
|
||||
@ -105,7 +111,8 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
|
||||
{"temperature": memory.temperature},
|
||||
get_memory_type_human(memory.memory_type),
|
||||
message_dict.get("user_input", ""),
|
||||
message_dict.get("agent_response", "")
|
||||
message_dict.get("agent_response", ""),
|
||||
task_id=task_id
|
||||
)
|
||||
message_list = [{
|
||||
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
|
||||
@ -122,13 +129,18 @@ async def save_extracted_to_memory_only(memory_id: str, message_dict, source_mes
|
||||
"status": True
|
||||
} for content in extracted_content]
|
||||
if not message_list:
|
||||
return True, "No memory extracted from raw message."
|
||||
msg = "No memory extracted from raw message."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return True, msg
|
||||
|
||||
return await embed_and_save(memory, message_list)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp())+ " " + f"Extracted {len(message_list)} messages from raw dialogue."})
|
||||
return await embed_and_save(memory, message_list, task_id)
|
||||
|
||||
|
||||
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
|
||||
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
|
||||
if not llm_type:
|
||||
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
|
||||
@ -143,8 +155,12 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
|
||||
else:
|
||||
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
||||
llm = LLMBundle(tenant_id, llm_type, llm_id)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."})
|
||||
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
|
||||
res_json = get_json_result_from_llm_response(res)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Get extracted result from LLM."})
|
||||
return [{
|
||||
"content": extracted_content["content"],
|
||||
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
|
||||
@ -153,16 +169,23 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
|
||||
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
|
||||
|
||||
|
||||
async def embed_and_save(memory, message_list: list[dict]):
|
||||
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
|
||||
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."})
|
||||
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
|
||||
for idx, msg in enumerate(message_list):
|
||||
msg["content_embed"] = vector_list[idx]
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Embedded extracted content."})
|
||||
vector_dimension = len(vector_list[0])
|
||||
if not MessageService.has_index(memory.tenant_id, memory.id):
|
||||
created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
|
||||
if not created:
|
||||
return False, "Failed to create message index."
|
||||
error_msg = "Failed to create message index."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
|
||||
return False, error_msg
|
||||
|
||||
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
|
||||
current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
|
||||
@ -174,11 +197,19 @@ async def embed_and_save(memory, message_list: list[dict]):
|
||||
MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
|
||||
decrease_memory_size_cache(memory.id, delete_size)
|
||||
else:
|
||||
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
|
||||
return False, error_msg
|
||||
fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
|
||||
if fail_cases:
|
||||
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
|
||||
return False, error_msg
|
||||
|
||||
if task_id:
|
||||
TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Saved messages to storage."})
|
||||
increase_memory_size_cache(memory.id, new_msg_size)
|
||||
return True, "Message saved successfully."
|
||||
|
||||
@ -379,11 +410,11 @@ async def handle_save_to_memory_task(task_param: dict):
|
||||
memory_id = task_param["memory_id"]
|
||||
source_id = task_param["source_id"]
|
||||
message_dict = task_param["message_dict"]
|
||||
success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id)
|
||||
success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id)
|
||||
if success:
|
||||
TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg})
|
||||
TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return True, msg
|
||||
|
||||
logging.error(msg)
|
||||
TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None})
|
||||
TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
|
||||
return False, msg
|
||||
|
||||
@ -190,10 +190,15 @@ class CommonService:
|
||||
data_list (list): List of dictionaries containing record data to insert.
|
||||
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
|
||||
"""
|
||||
current_ts = current_timestamp()
|
||||
current_datetime = datetime_format(datetime.now())
|
||||
with DB.atomic():
|
||||
for d in data_list:
|
||||
d["create_time"] = current_timestamp()
|
||||
d["create_date"] = datetime_format(datetime.now())
|
||||
d["create_time"] = current_ts
|
||||
d["create_date"] = current_datetime
|
||||
d["update_time"] = current_ts
|
||||
d["update_date"] = current_datetime
|
||||
|
||||
for i in range(0, len(data_list), batch_size):
|
||||
cls.model.insert_many(data_list[i : i + batch_size]).execute()
|
||||
|
||||
|
||||
@ -29,7 +29,6 @@ from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
|
||||
|
||||
class ConnectorService(CommonService):
|
||||
model = Connector
|
||||
|
||||
@ -202,6 +201,7 @@ class SyncLogsService(CommonService):
|
||||
return None
|
||||
|
||||
class FileObj(BaseModel):
|
||||
id: str
|
||||
filename: str
|
||||
blob: bytes
|
||||
|
||||
@ -209,7 +209,7 @@ class SyncLogsService(CommonService):
|
||||
return self.blob
|
||||
|
||||
errs = []
|
||||
files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
|
||||
files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
|
||||
doc_ids = []
|
||||
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
|
||||
errs.extend(err)
|
||||
|
||||
@ -64,11 +64,13 @@ class ConversationService(CommonService):
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
|
||||
def structure_answer(conv, ans, message_id, session_id):
|
||||
reference = ans["reference"]
|
||||
if not isinstance(reference, dict):
|
||||
reference = {}
|
||||
ans["reference"] = {}
|
||||
is_final = ans.get("final", True)
|
||||
|
||||
chunk_list = chunks_format(reference)
|
||||
|
||||
@ -81,14 +83,32 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
content = ans["answer"]
|
||||
if ans.get("start_to_think"):
|
||||
content = "<think>"
|
||||
elif ans.get("end_to_think"):
|
||||
content = "</think>"
|
||||
|
||||
if not conv.message or conv.message[-1].get("role", "") != "assistant":
|
||||
conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
|
||||
conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
|
||||
else:
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
|
||||
if is_final:
|
||||
if ans.get("answer"):
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
|
||||
else:
|
||||
conv.message[-1]["created_at"] = time.time()
|
||||
conv.message[-1]["id"] = message_id
|
||||
else:
|
||||
conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
|
||||
conv.message[-1]["created_at"] = time.time()
|
||||
conv.message[-1]["id"] = message_id
|
||||
if conv.reference:
|
||||
conv.reference[-1] = reference
|
||||
should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
|
||||
if should_update_reference:
|
||||
conv.reference[-1] = reference
|
||||
return ans
|
||||
|
||||
|
||||
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||
assert name, "`name` can not be empty."
|
||||
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import binascii
|
||||
import logging
|
||||
import re
|
||||
@ -23,7 +24,6 @@ from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
from langfuse import Langfuse
|
||||
from peewee import fn
|
||||
from agentic_reasoning import DeepResearcher
|
||||
from api.db.services.file_service import FileService
|
||||
from common.constants import LLMType, ParserType, StatusEnum
|
||||
from api.db.db_models import DB, Dialog
|
||||
@ -36,7 +36,7 @@ from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.advanced_rag import DeepResearcher
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||
@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if stream:
|
||||
last_ans = ""
|
||||
delta_ans = ""
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
delta_ans = ""
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
@ -279,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
||||
|
||||
|
||||
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
logging.debug("Begin async_chat")
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||
async for ans in async_chat_solo(dialog, messages, stream):
|
||||
@ -301,10 +296,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
|
||||
if langfuse_keys:
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||
if langfuse.auth_check():
|
||||
langfuse_tracer = langfuse
|
||||
trace_id = langfuse_tracer.create_trace_id()
|
||||
trace_context = {"trace_id": trace_id}
|
||||
try:
|
||||
if langfuse.auth_check():
|
||||
langfuse_tracer = langfuse
|
||||
trace_id = langfuse_tracer.create_trace_id()
|
||||
trace_context = {"trace_id": trace_id}
|
||||
except Exception:
|
||||
# Skip langfuse tracing if connection fails
|
||||
pass
|
||||
|
||||
check_langfuse_tracer_ts = timer()
|
||||
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
|
||||
@ -324,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
logging.debug(f"field_map retrieved: {field_map}")
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
# For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
|
||||
if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
|
||||
yield ans
|
||||
return
|
||||
else:
|
||||
logging.debug("SQL failed or returned no results, falling back to vector search")
|
||||
|
||||
param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
|
||||
logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":
|
||||
@ -367,10 +373,11 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
knowledges = []
|
||||
|
||||
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
|
||||
if attachments is not None and "knowledge" in param_keys:
|
||||
logging.debug("Proceeding with retrieval")
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
knowledges = []
|
||||
if prompt_config.get("reasoning", False):
|
||||
if prompt_config.get("reasoning", False) or kwargs.get("reasoning"):
|
||||
reasoner = DeepResearcher(
|
||||
chat_mdl,
|
||||
prompt_config,
|
||||
@ -386,16 +393,28 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
doc_ids=attachments,
|
||||
),
|
||||
)
|
||||
queue = asyncio.Queue()
|
||||
async def callback(msg:str):
|
||||
nonlocal queue
|
||||
await queue.put(msg + "<br/>")
|
||||
|
||||
await callback("<START_DEEP_RESEARCH>")
|
||||
task = asyncio.create_task(reasoner.research(kbinfos, questions[-1], questions[-1], callback=callback))
|
||||
while True:
|
||||
msg = await queue.get()
|
||||
if msg.find("<START_DEEP_RESEARCH>") == 0:
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
|
||||
elif msg.find("<END_DEEP_RESEARCH>") == 0:
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
|
||||
break
|
||||
else:
|
||||
yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
|
||||
|
||||
await task
|
||||
|
||||
async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
elif stream:
|
||||
yield think
|
||||
else:
|
||||
if embd_mdl:
|
||||
kbinfos = retriever.retrieval(
|
||||
kbinfos = await retriever.retrieval(
|
||||
" ".join(questions),
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
@ -411,7 +430,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
rank_feature=label_question(" ".join(questions), kbs),
|
||||
)
|
||||
if prompt_config.get("toc_enhance"):
|
||||
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
|
||||
cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
|
||||
@ -426,16 +445,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
retrieval_ts = timer()
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||
"audio_binary": tts(tts_mdl, empty_res)}
|
||||
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
"audio_binary": tts(tts_mdl, empty_res), "final": True}
|
||||
return
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
@ -538,21 +555,22 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
)
|
||||
|
||||
if stream:
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||
if thought:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
|
||||
full_answer = last_state.full_text if last_state else ""
|
||||
if full_answer:
|
||||
final = decorate_answer(thought + full_answer)
|
||||
final["final"] = True
|
||||
final["audio_binary"] = None
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
@ -565,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
|
||||
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = """
|
||||
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
|
||||
Ensure that:
|
||||
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
|
||||
2. Write only the SQL, no explanations or additional text.
|
||||
"""
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
{}
|
||||
logging.debug(f"use_sql: Question: {question}")
|
||||
|
||||
Question are as follows:
|
||||
# Determine which document engine we're using
|
||||
doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es"
|
||||
|
||||
# Construct the full table name
|
||||
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
|
||||
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
|
||||
base_table = index_name(tenant_id)
|
||||
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
|
||||
# Infinity: append kb_id to table name
|
||||
table_name = f"{base_table}_{kb_ids[0]}"
|
||||
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
|
||||
else:
|
||||
# Elasticsearch/OpenSearch: use base index name
|
||||
table_name = base_table
|
||||
logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
|
||||
|
||||
# Generate engine-specific SQL prompts
|
||||
if doc_engine == "infinity":
|
||||
# Build Infinity prompts with JSON extraction context
|
||||
json_field_names = list(field_map.keys())
|
||||
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
|
||||
|
||||
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
|
||||
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
|
||||
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
|
||||
|
||||
RULES:
|
||||
1. Use EXACT field names (case-sensitive) from the list below
|
||||
2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
|
||||
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
|
||||
4. Add AS alias for extracted field names
|
||||
5. DO NOT select 'content' field
|
||||
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
|
||||
- Question asks to "show me" or "display" specific columns
|
||||
- Question mentions "not null" or "excluding null"
|
||||
- Add NULL check for count specific column
|
||||
- DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
|
||||
7. Output ONLY the SQL, no explanations"""
|
||||
user_prompt = """Table: {}
|
||||
Fields (EXACT case): {}
|
||||
{}
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
|
||||
Question: {}
|
||||
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
|
||||
table_name,
|
||||
", ".join(json_field_names),
|
||||
"\n".join([f" - {field}" for field in json_field_names]),
|
||||
question
|
||||
)
|
||||
else:
|
||||
# Build ES/OS prompts with direct field access
|
||||
sys_prompt = """You are a Database Administrator. Write SQL queries.
|
||||
|
||||
RULES:
|
||||
1. Use EXACT field names from the schema below (e.g., product_tks, not product)
|
||||
2. Quote field names starting with digit: "123_field"
|
||||
3. Add IS NOT NULL in WHERE clause when:
|
||||
- Question asks to "show me" or "display" specific columns
|
||||
4. Include doc_id/docnm in non-aggregate statement
|
||||
5. Output ONLY the SQL, no explanations"""
|
||||
user_prompt = """Table: {}
|
||||
Available fields:
|
||||
{}
|
||||
Question: {}
|
||||
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
|
||||
table_name,
|
||||
"\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
|
||||
question
|
||||
)
|
||||
|
||||
tried_times = 0
|
||||
|
||||
async def get_table():
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
|
||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
sql = re.sub(r"([;;]|```).*", "", sql)
|
||||
sql = re.sub(r"&", "and", sql)
|
||||
if sql[: len("select ")] != "select ":
|
||||
return None, None
|
||||
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
|
||||
if sql[: len("select *")] != "select *":
|
||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||
else:
|
||||
flds = []
|
||||
for k in field_map.keys():
|
||||
if k in forbidden_select_fields4resume:
|
||||
continue
|
||||
if len(flds) > 11:
|
||||
break
|
||||
flds.append(k)
|
||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
|
||||
# Remove think blocks if present (format: </think>...)
|
||||
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
|
||||
sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
|
||||
# Remove markdown code blocks (```sql ... ```)
|
||||
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
|
||||
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
|
||||
# Remove trailing semicolon that ES SQL parser doesn't like
|
||||
sql = sql.rstrip().rstrip(';').strip()
|
||||
|
||||
if kb_ids:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
if "where" not in sql.lower():
|
||||
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
|
||||
if doc_engine != "infinity" and kb_ids:
|
||||
# Build kb_filter: single KB or multiple KBs with OR
|
||||
if len(kb_ids) == 1:
|
||||
kb_filter = f"kb_id = '{kb_ids[0]}'"
|
||||
else:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
|
||||
if "where " not in sql.lower():
|
||||
o = sql.lower().split("order by")
|
||||
if len(o) > 1:
|
||||
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
|
||||
else:
|
||||
sql += f" WHERE {kb_filter}"
|
||||
else:
|
||||
sql += f" AND {kb_filter}"
|
||||
elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
|
||||
sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
|
||||
|
||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
|
||||
tbl = settings.retriever.sql_retrieval(sql, format="json")
|
||||
if tbl is None:
|
||||
logging.debug("use_sql: SQL retrieval returned None")
|
||||
return None, sql
|
||||
logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
|
||||
return tbl, sql
|
||||
|
||||
try:
|
||||
tbl, sql = await get_table()
|
||||
logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
|
||||
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
|
||||
except Exception as e:
|
||||
user_prompt = """
|
||||
logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
|
||||
# Build retry prompt with error information
|
||||
if doc_engine == "infinity":
|
||||
# Build Infinity error retry prompt
|
||||
json_field_names = list(field_map.keys())
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
|
||||
{}
|
||||
|
||||
Question: {}
|
||||
Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
|
||||
|
||||
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
|
||||
""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
|
||||
else:
|
||||
# Build ES/OS error retry prompt
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
Table of database fields are as follows (use the field names directly in SQL):
|
||||
{}
|
||||
|
||||
Question are as follows:
|
||||
{}
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
|
||||
|
||||
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
|
||||
Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
|
||||
""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
|
||||
try:
|
||||
tbl, sql = await get_table()
|
||||
logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
|
||||
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
|
||||
except Exception:
|
||||
logging.error("use_sql: Retry SQL execution also FAILED, returning None")
|
||||
return
|
||||
|
||||
if len(tbl["rows"]) == 0:
|
||||
logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
|
||||
return None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
|
||||
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
|
||||
|
||||
logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
|
||||
logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
|
||||
|
||||
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
||||
|
||||
logging.debug(f"use_sql: column_idx={column_idx}")
|
||||
logging.debug(f"use_sql: field_map={field_map}")
|
||||
|
||||
# Helper function to map column names to display names
|
||||
def map_column_name(col_name):
|
||||
if col_name.lower() == "count(star)":
|
||||
return "COUNT(*)"
|
||||
|
||||
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
|
||||
# Pattern: anything AS alias_name
|
||||
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
|
||||
if as_match:
|
||||
alias = as_match.group(1).strip('"\'')
|
||||
|
||||
# Use the alias for display name lookup
|
||||
if alias in field_map:
|
||||
display = field_map[alias]
|
||||
return re.sub(r"(/.*|([^()]+))", "", display)
|
||||
# If alias not in field_map, try to match case-insensitively
|
||||
for field_key, display_value in field_map.items():
|
||||
if field_key.lower() == alias.lower():
|
||||
return re.sub(r"(/.*|([^()]+))", "", display_value)
|
||||
# Return alias as-is if no mapping found
|
||||
return alias
|
||||
|
||||
# Try direct mapping first (for simple column names)
|
||||
if col_name in field_map:
|
||||
display = field_map[col_name]
|
||||
# Clean up any suffix patterns
|
||||
return re.sub(r"(/.*|([^()]+))", "", display)
|
||||
|
||||
# Try case-insensitive match for simple column names
|
||||
col_lower = col_name.lower()
|
||||
for field_key, display_value in field_map.items():
|
||||
if field_key.lower() == col_lower:
|
||||
return re.sub(r"(/.*|([^()]+))", "", display_value)
|
||||
|
||||
# For aggregate expressions or complex expressions without AS alias,
|
||||
# try to replace field names with display names
|
||||
result = col_name
|
||||
for field_name, display_name in field_map.items():
|
||||
# Replace field_name with display_name in the expression
|
||||
result = result.replace(field_name, display_name)
|
||||
|
||||
# Clean up any suffix patterns
|
||||
result = re.sub(r"(/.*|([^()]+))", "", result)
|
||||
return result
|
||||
|
||||
# compose Markdown table
|
||||
columns = (
|
||||
"|" + "|".join(
|
||||
[re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
|
||||
"|Source|" if docid_idx and docid_idx else "|")
|
||||
[map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
|
||||
"|Source|" if docid_idx and doc_name_idx else "|")
|
||||
)
|
||||
|
||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
|
||||
rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
|
||||
# Build rows ensuring column names match values - create a dict for each row
|
||||
# keyed by column name to handle any SQL column order
|
||||
rows = []
|
||||
for row_idx, r in enumerate(tbl["rows"]):
|
||||
row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
|
||||
if row_idx == 0:
|
||||
logging.debug(f"use_sql: First row data: {row_dict}")
|
||||
row_values = []
|
||||
for col_idx in column_idx:
|
||||
col_name = tbl["columns"][col_idx]["name"]
|
||||
value = row_dict.get(col_name, " ")
|
||||
row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
|
||||
# Add Source column with citation marker if Source column exists
|
||||
if docid_idx and doc_name_idx:
|
||||
row_values.append(f" ##{row_idx}$$")
|
||||
row_str = "|" + "|".join(row_values) + "|"
|
||||
if re.sub(r"[ |]+", "", row_str):
|
||||
rows.append(row_str)
|
||||
if quota:
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
rows = "\n".join(rows)
|
||||
else:
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
rows = "\n".join(rows)
|
||||
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
||||
|
||||
if not docid_idx or not doc_name_idx:
|
||||
logging.warning("SQL missing field: " + sql)
|
||||
logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
|
||||
# For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
|
||||
# to provide source chunks, but keep the original table format answer
|
||||
if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
|
||||
# Keep original table format as answer
|
||||
answer = "\n".join([columns, line, rows])
|
||||
|
||||
# Now fetch doc_id, docnm_kwd to provide source chunks
|
||||
# Extract WHERE clause from the original SQL
|
||||
where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
|
||||
if where_match:
|
||||
where_clause = where_match.group(1).strip()
|
||||
# Build a query to get doc_id and docnm_kwd with the same WHERE clause
|
||||
chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
|
||||
# Add LIMIT to avoid fetching too many chunks
|
||||
if "limit" not in chunks_sql.lower():
|
||||
chunks_sql += " limit 20"
|
||||
logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
|
||||
try:
|
||||
chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
|
||||
if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
|
||||
# Build chunks reference - use case-insensitive matching
|
||||
chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
|
||||
chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
|
||||
if chunks_did_idx is not None and chunks_dn_idx is not None:
|
||||
chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
|
||||
# Build doc_aggs
|
||||
doc_aggs = {}
|
||||
for r in chunks_tbl["rows"]:
|
||||
doc_id = r[chunks_did_idx]
|
||||
doc_name = r[chunks_dn_idx]
|
||||
if doc_id not in doc_aggs:
|
||||
doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
|
||||
doc_aggs[doc_id]["count"] += 1
|
||||
doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
|
||||
logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
|
||||
return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
|
||||
except Exception as e:
|
||||
logging.warning(f"use_sql: Failed to fetch chunks: {e}")
|
||||
# Fallback: return answer without chunks
|
||||
return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
|
||||
# Fallback to table format for other cases
|
||||
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
|
||||
|
||||
docid_idx = list(docid_idx)[0]
|
||||
@ -680,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
if r[docid_idx] not in doc_aggs:
|
||||
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
|
||||
doc_aggs[r[docid_idx]]["count"] += 1
|
||||
return {
|
||||
|
||||
result = {
|
||||
"answer": "\n".join([columns, line, rows]),
|
||||
"reference": {
|
||||
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
|
||||
@ -688,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
},
|
||||
"prompt": sys_prompt,
|
||||
}
|
||||
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
|
||||
return result
|
||||
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
@ -733,6 +948,84 @@ def tts(tts_mdl, text):
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
|
||||
class _ThinkStreamState:
|
||||
def __init__(self) -> None:
|
||||
self.full_text = ""
|
||||
self.last_idx = 0
|
||||
self.endswith_think = False
|
||||
self.last_full = ""
|
||||
self.last_model_full = ""
|
||||
self.in_think = False
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
def _next_think_delta(state: _ThinkStreamState) -> str:
|
||||
full_text = state.full_text
|
||||
if full_text == state.last_full:
|
||||
return ""
|
||||
state.last_full = full_text
|
||||
delta_ans = full_text[state.last_idx:]
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
state.last_idx += len("<think>")
|
||||
return "<think>"
|
||||
if delta_ans.find("<think>") > 0:
|
||||
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
|
||||
state.last_idx += delta_ans.find("<think>")
|
||||
return delta_text
|
||||
if delta_ans.endswith("</think>"):
|
||||
state.endswith_think = True
|
||||
elif state.endswith_think:
|
||||
state.endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
state.last_idx = len(full_text)
|
||||
if full_text.endswith("</think>"):
|
||||
state.last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
|
||||
async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
|
||||
state = _ThinkStreamState()
|
||||
async for chunk in stream_iter:
|
||||
if not chunk:
|
||||
continue
|
||||
if chunk.startswith(state.last_model_full):
|
||||
new_part = chunk[len(state.last_model_full):]
|
||||
state.last_model_full = chunk
|
||||
else:
|
||||
new_part = chunk
|
||||
state.last_model_full += chunk
|
||||
if not new_part:
|
||||
continue
|
||||
state.full_text += new_part
|
||||
delta = _next_think_delta(state)
|
||||
if not delta:
|
||||
continue
|
||||
if delta in ("<think>", "</think>"):
|
||||
if delta == "<think>" and state.in_think:
|
||||
continue
|
||||
if delta == "</think>" and not state.in_think:
|
||||
continue
|
||||
if state.buffer:
|
||||
yield ("text", state.buffer, state)
|
||||
state.buffer = ""
|
||||
state.in_think = delta == "<think>"
|
||||
yield ("marker", delta, state)
|
||||
continue
|
||||
state.buffer += delta
|
||||
if num_tokens_from_string(state.buffer) < min_tokens:
|
||||
continue
|
||||
yield ("text", state.buffer, state)
|
||||
state.buffer = ""
|
||||
|
||||
if state.buffer:
|
||||
yield ("text", state.buffer, state)
|
||||
state.buffer = ""
|
||||
if state.endswith_think:
|
||||
yield ("marker", "</think>", state)
|
||||
|
||||
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_mdl = None
|
||||
@ -758,7 +1051,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
kbinfos = await retriever.retrieval(
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
@ -798,11 +1091,20 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
refs["chunks"] = chunks_format(refs)
|
||||
return {"answer": answer, "reference": refs}
|
||||
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "final": False, **flags}
|
||||
continue
|
||||
yield {"answer": value, "reference": {}, "final": False}
|
||||
full_answer = last_state.full_text if last_state else ""
|
||||
final = decorate_answer(full_answer)
|
||||
final["final"] = True
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
|
||||
|
||||
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
@ -825,7 +1127,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = await settings.retriever.retrieval(
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
|
||||
@ -340,14 +340,35 @@ class DocumentService(CommonService):
|
||||
def remove_document(cls, doc, tenant_id):
|
||||
from api.db.services.task_service import TaskService
|
||||
cls.clear_chunk_num(doc.id)
|
||||
|
||||
# Delete tasks first
|
||||
try:
|
||||
TaskService.filter_delete([Task.doc_id == doc.id])
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to delete tasks for document {doc.id}: {e}")
|
||||
|
||||
# Delete chunk images (non-critical, log and continue)
|
||||
try:
|
||||
cls.delete_chunk_images(doc, tenant_id)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}")
|
||||
|
||||
# Delete thumbnail (non-critical, log and continue)
|
||||
try:
|
||||
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to delete thumbnail for document {doc.id}: {e}")
|
||||
|
||||
# Delete chunks from doc store - this is critical, log errors
|
||||
try:
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}")
|
||||
|
||||
# Cleanup knowledge graph references (non-critical, log and continue)
|
||||
try:
|
||||
graph_source = settings.docStoreConn.get_fields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
@ -360,8 +381,9 @@ class DocumentService(CommonService):
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
|
||||
|
||||
return cls.delete_by_id(doc.id)
|
||||
|
||||
@classmethod
|
||||
@ -423,6 +445,7 @@ class DocumentService(CommonService):
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
|
||||
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
|
||||
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
|
||||
return list(docs.dicts())
|
||||
@ -753,10 +776,25 @@ class DocumentService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_metadata_summary(cls, kb_id):
|
||||
def get_metadata_summary(cls, kb_id, document_ids=None):
|
||||
def _meta_value_type(value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return "list"
|
||||
if isinstance(value, bool):
|
||||
return "string"
|
||||
if isinstance(value, (int, float)):
|
||||
return "number"
|
||||
return "string"
|
||||
|
||||
fields = [cls.model.id, cls.model.meta_fields]
|
||||
summary = {}
|
||||
for r in cls.model.select(*fields).where(cls.model.kb_id == kb_id):
|
||||
type_counter = {}
|
||||
query = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||
if document_ids:
|
||||
query = query.where(cls.model.id.in_(document_ids))
|
||||
for r in query:
|
||||
meta_fields = r.meta_fields or {}
|
||||
if isinstance(meta_fields, str):
|
||||
try:
|
||||
@ -766,6 +804,11 @@ class DocumentService(CommonService):
|
||||
if not isinstance(meta_fields, dict):
|
||||
continue
|
||||
for k, v in meta_fields.items():
|
||||
value_type = _meta_value_type(v)
|
||||
if value_type:
|
||||
if k not in type_counter:
|
||||
type_counter[k] = {}
|
||||
type_counter[k][value_type] = type_counter[k].get(value_type, 0) + 1
|
||||
values = v if isinstance(v, list) else [v]
|
||||
for vv in values:
|
||||
if not vv:
|
||||
@ -774,11 +817,19 @@ class DocumentService(CommonService):
|
||||
if k not in summary:
|
||||
summary[k] = {}
|
||||
summary[k][sv] = summary[k].get(sv, 0) + 1
|
||||
return {k: sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True) for k, v in summary.items()}
|
||||
result = {}
|
||||
for k, v in summary.items():
|
||||
values = sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True)
|
||||
type_counts = type_counter.get(k, {})
|
||||
value_type = "string"
|
||||
if type_counts:
|
||||
value_type = max(type_counts.items(), key=lambda item: item[1])[0]
|
||||
result[k] = {"type": value_type, "values": values}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None):
|
||||
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None, adds=None):
|
||||
updates = updates or []
|
||||
deletes = deletes or []
|
||||
if not doc_ids:
|
||||
@ -801,11 +852,20 @@ class DocumentService(CommonService):
|
||||
changed = False
|
||||
for upd in updates:
|
||||
key = upd.get("key")
|
||||
if not key or key not in meta:
|
||||
if not key:
|
||||
continue
|
||||
if key not in meta:
|
||||
meta[key] = upd.get("value")
|
||||
|
||||
new_value = upd.get("value")
|
||||
match_provided = "match" in upd
|
||||
if key not in meta:
|
||||
if match_provided:
|
||||
continue
|
||||
meta[key] = dedupe_list(new_value) if isinstance(new_value, list) else new_value
|
||||
changed = True
|
||||
continue
|
||||
|
||||
if isinstance(meta[key], list):
|
||||
if not match_provided:
|
||||
if isinstance(new_value, list):
|
||||
@ -865,7 +925,7 @@ class DocumentService(CommonService):
|
||||
updated_docs = 0
|
||||
with DB.atomic():
|
||||
rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
|
||||
(cls.model.id.in_(doc_ids)) & (cls.model.kb_id == kb_id)
|
||||
cls.model.id.in_(doc_ids)
|
||||
)
|
||||
for r in rows:
|
||||
meta = _normalize_meta(r.meta_fields or {})
|
||||
@ -914,6 +974,8 @@ class DocumentService(CommonService):
|
||||
bad = 0
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
if status == TaskStatus.CANCEL.value:
|
||||
continue
|
||||
doc_progress = doc.progress if doc and doc.progress else 0.0
|
||||
special_task_running = False
|
||||
priority = 0
|
||||
@ -957,7 +1019,16 @@ class DocumentService(CommonService):
|
||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
else:
|
||||
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
cls.update_by_id(d["id"], info)
|
||||
info["update_time"] = current_timestamp()
|
||||
info["update_date"] = get_format_time()
|
||||
(
|
||||
cls.model.update(info)
|
||||
.where(
|
||||
(cls.model.id == d["id"])
|
||||
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
except Exception as e:
|
||||
if str(e).find("'0'") < 0:
|
||||
logging.exception("fetch task exception")
|
||||
@ -990,7 +1061,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
|
||||
# cancelled: run == "2" but progress can vary
|
||||
# cancelled: run == "2"
|
||||
cancelled = (
|
||||
cls.model.select(fn.COUNT(1))
|
||||
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
|
||||
@ -1245,7 +1316,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.index_exist(idxnm, kb_id):
|
||||
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
|
||||
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
@ -225,21 +225,36 @@ class EvaluationService(CommonService):
|
||||
"""
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
case_instances = []
|
||||
|
||||
if not cases:
|
||||
return success_count, failure_count
|
||||
|
||||
cur_timestamp = current_timestamp()
|
||||
|
||||
for case_data in cases:
|
||||
success, _ = cls.add_test_case(
|
||||
dataset_id=dataset_id,
|
||||
question=case_data.get("question", ""),
|
||||
reference_answer=case_data.get("reference_answer"),
|
||||
relevant_doc_ids=case_data.get("relevant_doc_ids"),
|
||||
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
|
||||
metadata=case_data.get("metadata")
|
||||
)
|
||||
try:
|
||||
for case_data in cases:
|
||||
case_id = get_uuid()
|
||||
case_info = {
|
||||
"id": case_id,
|
||||
"dataset_id": dataset_id,
|
||||
"question": case_data.get("question", ""),
|
||||
"reference_answer": case_data.get("reference_answer"),
|
||||
"relevant_doc_ids": case_data.get("relevant_doc_ids"),
|
||||
"relevant_chunk_ids": case_data.get("relevant_chunk_ids"),
|
||||
"metadata": case_data.get("metadata"),
|
||||
"create_time": cur_timestamp
|
||||
}
|
||||
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
case_instances.append(EvaluationCase(**case_info))
|
||||
EvaluationCase.bulk_create(case_instances, batch_size=300)
|
||||
success_count = len(case_instances)
|
||||
failure_count = 0
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error bulk importing test cases: {str(e)}")
|
||||
failure_count = len(cases)
|
||||
success_count = 0
|
||||
|
||||
return success_count, failure_count
|
||||
|
||||
|
||||
@ -439,6 +439,15 @@ class FileService(CommonService):
|
||||
|
||||
err, files = [], []
|
||||
for file in file_objs:
|
||||
doc_id = file.id if hasattr(file, "id") else get_uuid()
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e:
|
||||
blob = file.read()
|
||||
settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id)
|
||||
doc.size = len(blob)
|
||||
doc = doc.to_dict()
|
||||
DocumentService.update_by_id(doc["id"], doc)
|
||||
continue
|
||||
try:
|
||||
DocumentService.check_doc_health(kb.tenant_id, file.filename)
|
||||
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
|
||||
@ -455,7 +464,6 @@ class FileService(CommonService):
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
settings.STORAGE_IMPL.put(kb.id, location, blob)
|
||||
|
||||
doc_id = get_uuid()
|
||||
|
||||
img = thumbnail_img(filename, blob)
|
||||
thumbnail_location = ""
|
||||
|
||||
@ -397,7 +397,7 @@ class KnowledgebaseService(CommonService):
|
||||
if dataset_name == "":
|
||||
return False, get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}")
|
||||
|
||||
# Deduplicate name within tenant
|
||||
dataset_name = duplicate_name(
|
||||
|
||||
@ -441,3 +441,46 @@ class LLMBundle(LLM4Tenant):
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
return
|
||||
|
||||
async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
total_tokens = 0
|
||||
ans = ""
|
||||
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
||||
elif hasattr(self.mdl, "async_chat_streamly"):
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
||||
else:
|
||||
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||
|
||||
generation = None
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
try:
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans[: -len("</think>")]
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
ans += txt
|
||||
yield txt
|
||||
except Exception as e:
|
||||
if generation:
|
||||
generation.update(output={"error": str(e)})
|
||||
generation.end()
|
||||
raise
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
if generation:
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
return
|
||||
|
||||
@ -167,4 +167,4 @@ class MemoryService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_memory(cls, memory_id: str):
|
||||
return cls.model.delete().where(cls.model.id == memory_id).execute()
|
||||
return cls.delete_by_id(memory_id)
|
||||
|
||||
44
api/db/services/system_settings_service.py
Normal file
44
api/db/services/system_settings_service.py
Normal file
@ -0,0 +1,44 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from api.db.db_models import DB
|
||||
from api.db.db_models import SystemSettings
|
||||
from api.db.services.common_service import CommonService
|
||||
|
||||
|
||||
class SystemSettingsService(CommonService):
|
||||
model = SystemSettings
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_name(cls, name):
|
||||
objs = cls.model.select().where(cls.model.name.startswith(name))
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_by_name(cls, name, obj):
|
||||
obj["update_time"] = current_timestamp()
|
||||
obj["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(obj).where(cls.model.name.startswith(name)).execute()
|
||||
return SystemSettings(**obj)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_record_count(cls):
|
||||
count = cls.model.select().count()
|
||||
return count
|
||||
@ -121,13 +121,6 @@ class TaskService(CommonService):
|
||||
.where(cls.model.id == task_id)
|
||||
)
|
||||
docs = list(docs.dicts())
|
||||
# Assuming docs = list(docs.dicts())
|
||||
if docs:
|
||||
kb_config = docs[0]['kb_parser_config'] # Dict from Knowledgebase.parser_config
|
||||
mineru_method = kb_config.get('mineru_parse_method', 'auto')
|
||||
mineru_formula = kb_config.get('mineru_formula_enable', True)
|
||||
mineru_table = kb_config.get('mineru_table_enable', True)
|
||||
print(mineru_method, mineru_formula, mineru_table)
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
@ -179,6 +172,40 @@ class TaskService(CommonService):
|
||||
return None
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tasks_progress_by_doc_ids(cls, doc_ids: list[str]):
|
||||
"""Retrieve all tasks associated with specific documents.
|
||||
|
||||
This method fetches all processing tasks for given document ids, ordered by
|
||||
creation time. It includes task progress and chunk information.
|
||||
|
||||
Args:
|
||||
doc_ids (str): The unique identifier of the document.
|
||||
|
||||
Returns:
|
||||
list[dict]: List of task dictionaries containing task details.
|
||||
Returns None if no tasks are found.
|
||||
"""
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.doc_id,
|
||||
cls.model.from_page,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.digest,
|
||||
cls.model.chunk_ids,
|
||||
cls.model.create_time
|
||||
]
|
||||
tasks = (
|
||||
cls.model.select(*fields).order_by(cls.model.create_time.desc())
|
||||
.where(cls.model.doc_id.in_(doc_ids))
|
||||
)
|
||||
tasks = list(tasks.dicts())
|
||||
if not tasks:
|
||||
return None
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_chunk_ids(cls, id: str, chunk_ids: str):
|
||||
@ -495,6 +522,7 @@ def cancel_all_task_of(doc_id):
|
||||
def has_canceled(task_id):
|
||||
try:
|
||||
if REDIS_CONN.get(f"{task_id}-cancel"):
|
||||
logging.info(f"Task: {task_id} has been canceled")
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -19,7 +19,7 @@ import logging
|
||||
from peewee import IntegrityError
|
||||
from langfuse import Langfuse
|
||||
from common import settings
|
||||
from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType
|
||||
from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType
|
||||
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
@ -60,10 +60,8 @@ class TenantLLMService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
|
||||
cls.model.used_tokens, cls.model.status]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
|
||||
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
|
||||
@ -90,6 +88,7 @@ class TenantLLMService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
||||
from api.db.services.llm_service import LLMService
|
||||
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
@ -119,9 +118,9 @@ class TenantLLMService(CommonService):
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
|
||||
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
|
||||
embedding_cfg = settings.EMBEDDING_CFG
|
||||
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
|
||||
model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
|
||||
else:
|
||||
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
|
||||
|
||||
@ -140,33 +139,27 @@ class TenantLLMService(CommonService):
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return None
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"])
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
elif llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return None
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"])
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return None
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
|
||||
base_url=model_config["api_base"], **kwargs)
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return None
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"], **kwargs)
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
elif llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return None
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
|
||||
model_name=model_config["llm_name"], lang=lang,
|
||||
base_url=model_config["api_base"])
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
elif llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return None
|
||||
@ -216,14 +209,11 @@ class TenantLLMService(CommonService):
|
||||
try:
|
||||
num = (
|
||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name,
|
||||
cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.execute()
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
|
||||
tenant_id, llm_name)
|
||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
||||
return 0
|
||||
|
||||
return num
|
||||
@ -231,9 +221,7 @@ class TenantLLMService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"),
|
||||
~(cls.model.llm_name == "text-embedding-3-small"),
|
||||
~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@ -298,6 +286,68 @@ class TenantLLMService(CommonService):
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
@classmethod
|
||||
def _collect_paddleocr_env_config(cls) -> dict | None:
|
||||
cfg = PADDLEOCR_DEFAULT_CONFIG
|
||||
found = False
|
||||
for key in PADDLEOCR_ENV_KEYS:
|
||||
val = os.environ.get(key)
|
||||
if val:
|
||||
found = True
|
||||
cfg[key] = val
|
||||
return cfg if found else None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def ensure_paddleocr_from_env(cls, tenant_id: str) -> str | None:
|
||||
"""
|
||||
Ensure a PaddleOCR model exists for the tenant if env variables are present.
|
||||
Return the existing or newly created llm_name, or None if env not set.
|
||||
"""
|
||||
cfg = cls._collect_paddleocr_env_config()
|
||||
if not cfg:
|
||||
return None
|
||||
|
||||
saved_paddleocr_models = cls.query(tenant_id=tenant_id, llm_factory="PaddleOCR", model_type=LLMType.OCR.value)
|
||||
|
||||
def _parse_api_key(raw: str) -> dict:
|
||||
try:
|
||||
return json.loads(raw or "{}")
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
for item in saved_paddleocr_models:
|
||||
api_cfg = _parse_api_key(item.api_key)
|
||||
normalized = {k: api_cfg.get(k, PADDLEOCR_DEFAULT_CONFIG.get(k)) for k in PADDLEOCR_ENV_KEYS}
|
||||
if normalized == cfg:
|
||||
return item.llm_name
|
||||
|
||||
used_names = {item.llm_name for item in saved_paddleocr_models}
|
||||
idx = 1
|
||||
base_name = "paddleocr-from-env"
|
||||
while True:
|
||||
candidate = f"{base_name}-{idx}"
|
||||
if candidate in used_names:
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
cls.save(
|
||||
tenant_id=tenant_id,
|
||||
llm_factory="PaddleOCR",
|
||||
llm_name=candidate,
|
||||
model_type=LLMType.OCR.value,
|
||||
api_key=json.dumps(cfg),
|
||||
api_base="",
|
||||
max_tokens=0,
|
||||
)
|
||||
return candidate
|
||||
except IntegrityError:
|
||||
logging.warning("PaddleOCR env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
|
||||
used_names.add(candidate)
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
@ -306,6 +356,7 @@ class TenantLLMService(CommonService):
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
from api.db.services.llm_service import LLMService
|
||||
|
||||
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
||||
llm_factories = settings.FACTORY_LLM_INFOS
|
||||
for llm_factory in llm_factories:
|
||||
@ -340,9 +391,12 @@ class LLM4Tenant:
|
||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
||||
self.langfuse = None
|
||||
if langfuse_keys:
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key,
|
||||
host=langfuse_keys.host)
|
||||
if langfuse.auth_check():
|
||||
self.langfuse = langfuse
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||
try:
|
||||
if langfuse.auth_check():
|
||||
self.langfuse = langfuse
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
except Exception:
|
||||
# Skip langfuse tracing if connection fails
|
||||
pass
|
||||
|
||||
@ -18,8 +18,8 @@
|
||||
# from beartype.claw import beartype_all # <-- you didn't sign up for this
|
||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||
|
||||
from common.log_utils import init_root_logger
|
||||
from plugin import GlobalPluginManager
|
||||
import time
|
||||
start_ts = time.time()
|
||||
|
||||
import logging
|
||||
import os
|
||||
@ -40,6 +40,8 @@ from api.db.init_data import init_web_data, init_superuser
|
||||
from common.versions import get_ragflow_version
|
||||
from common.config_utils import show_configs
|
||||
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from common.log_utils import init_root_logger
|
||||
from plugin import GlobalPluginManager
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
stop_event = threading.Event()
|
||||
@ -145,7 +147,7 @@ if __name__ == '__main__':
|
||||
|
||||
# start http server
|
||||
try:
|
||||
logging.info("RAGFlow HTTP server start...")
|
||||
logging.info(f"RAGFlow server is ready after {time.time() - start_ts}s initialization.")
|
||||
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
@ -29,8 +29,15 @@ import requests
|
||||
from quart import (
|
||||
Response,
|
||||
jsonify,
|
||||
request
|
||||
request,
|
||||
has_app_context,
|
||||
)
|
||||
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
|
||||
|
||||
try:
|
||||
from quart.exceptions import BadRequest as QuartBadRequest
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
QuartBadRequest = None
|
||||
|
||||
from peewee import OperationalError
|
||||
|
||||
@ -42,41 +49,45 @@ from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
from common.connection_utils import timeout
|
||||
from common.constants import RetCode
|
||||
from common import settings
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
def _safe_jsonify(payload: dict):
|
||||
if has_app_context():
|
||||
return jsonify(payload)
|
||||
return payload
|
||||
|
||||
|
||||
async def _coerce_request_data() -> dict:
|
||||
"""Fetch JSON body with sane defaults; fallback to form data."""
|
||||
if hasattr(request, "_cached_payload"):
|
||||
return request._cached_payload
|
||||
payload: Any = None
|
||||
last_error: Exception | None = None
|
||||
|
||||
try:
|
||||
payload = await request.get_json(force=True, silent=True)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
payload = None
|
||||
body_bytes = await request.get_data()
|
||||
has_body = bool(body_bytes)
|
||||
content_type = (request.content_type or "").lower()
|
||||
is_json = content_type.startswith("application/json")
|
||||
|
||||
if payload is None:
|
||||
try:
|
||||
form = await request.form
|
||||
payload = form.to_dict()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
payload = None
|
||||
if not has_body:
|
||||
payload = {}
|
||||
elif is_json:
|
||||
payload = await request.get_json(force=False, silent=False)
|
||||
if isinstance(payload, dict):
|
||||
payload = payload or {}
|
||||
elif isinstance(payload, str):
|
||||
raise AttributeError("'str' object has no attribute 'get'")
|
||||
else:
|
||||
raise TypeError("JSON payload must be an object.")
|
||||
else:
|
||||
form = await request.form
|
||||
payload = form.to_dict() if form else None
|
||||
if payload is None:
|
||||
raise TypeError("Request body is not a valid form payload.")
|
||||
|
||||
if payload is None:
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise ValueError("No JSON body or form data found in request.")
|
||||
|
||||
if isinstance(payload, dict):
|
||||
return payload or {}
|
||||
|
||||
if isinstance(payload, str):
|
||||
raise AttributeError("'str' object has no attribute 'get'")
|
||||
|
||||
raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
|
||||
request._cached_payload = payload
|
||||
return payload
|
||||
|
||||
async def get_request_json():
|
||||
return await _coerce_request_data()
|
||||
@ -115,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
|
||||
continue
|
||||
else:
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def server_error_response(e):
|
||||
@ -124,16 +135,12 @@ def server_error_response(e):
|
||||
try:
|
||||
msg = repr(e).lower()
|
||||
if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
|
||||
return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
|
||||
resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized")
|
||||
resp.status_code = RetCode.UNAUTHORIZED
|
||||
return resp
|
||||
except Exception as ex:
|
||||
logging.warning(f"error checking authorization: {ex}")
|
||||
|
||||
if len(e.args) > 1:
|
||||
try:
|
||||
serialized_data = serialize_for_json(e.args[1])
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
|
||||
except Exception:
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
|
||||
if repr(e).find("index_not_found_exception") >= 0:
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
|
||||
|
||||
@ -168,7 +175,17 @@ def validate_request(*args, **kwargs):
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
async def decorated_function(*_args, **_kwargs):
|
||||
errs = process_args(await _coerce_request_data())
|
||||
exception_types = (AttributeError, TypeError, WerkzeugBadRequest)
|
||||
if QuartBadRequest is not None:
|
||||
exception_types = exception_types + (QuartBadRequest,)
|
||||
if args or kwargs:
|
||||
try:
|
||||
input_arguments = await _coerce_request_data()
|
||||
except exception_types:
|
||||
input_arguments = {}
|
||||
else:
|
||||
input_arguments = await _coerce_request_data()
|
||||
errs = process_args(input_arguments)
|
||||
if errs:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
@ -215,7 +232,7 @@ def active_required(func):
|
||||
|
||||
def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
|
||||
response = {"code": code, "message": message, "data": data}
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def apikey_required(func):
|
||||
@ -236,16 +253,16 @@ def apikey_required(func):
|
||||
|
||||
def build_error_result(code=RetCode.FORBIDDEN, message="success"):
|
||||
response = {"code": code, "message": message}
|
||||
response = jsonify(response)
|
||||
response.status_code = code
|
||||
response = _safe_jsonify(response)
|
||||
if hasattr(response, "status_code"):
|
||||
response.status_code = code
|
||||
return response
|
||||
|
||||
|
||||
def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
|
||||
if data is None:
|
||||
return jsonify({"code": code, "message": message})
|
||||
else:
|
||||
return jsonify({"code": code, "message": message, "data": data})
|
||||
return _safe_jsonify({"code": code, "message": message})
|
||||
return _safe_jsonify({"code": code, "message": message, "data": data})
|
||||
|
||||
|
||||
def token_required(func):
|
||||
@ -304,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
|
||||
else:
|
||||
response["message"] = message or "Error"
|
||||
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def get_error_data_result(
|
||||
@ -318,7 +335,7 @@ def get_error_data_result(
|
||||
continue
|
||||
else:
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def get_error_argument_result(message="Invalid arguments"):
|
||||
@ -683,7 +700,7 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
nonlocal chat_model, embedding_model
|
||||
if embedding_model:
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
|
||||
thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
|
||||
timeout=10
|
||||
)
|
||||
|
||||
|
||||
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import xxhash
|
||||
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(
|
||||
@ -22,3 +24,6 @@ def string_to_bytes(string):
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
|
||||
# 128 bit = 32 character
|
||||
def hash128(data: str) -> str:
|
||||
return xxhash.xxh128(data).hexdigest()
|
||||
|
||||
@ -24,7 +24,7 @@ from common.file_utils import get_project_base_directory
|
||||
|
||||
def crypt(line):
|
||||
"""
|
||||
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||
decrypt(crypt(input_string)) == base64(input_string), which frontend and ragflow_cli use.
|
||||
"""
|
||||
file_path = os.path.join(get_project_base_directory(), "conf", "public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
|
||||
@ -82,6 +82,8 @@ async def validate_and_parse_json_request(request: Request, validator: type[Base
|
||||
2. Extra fields added via `extras` parameter are automatically removed
|
||||
from the final output after validation
|
||||
"""
|
||||
if request.mimetype != "application/json":
|
||||
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
|
||||
try:
|
||||
payload = await request.get_json() or {}
|
||||
except UnsupportedMediaType:
|
||||
|
||||
@ -86,6 +86,9 @@ CONTENT_TYPE_MAP = {
|
||||
"ico": "image/x-icon",
|
||||
"avif": "image/avif",
|
||||
"heic": "image/heic",
|
||||
# PPTX
|
||||
"ppt": "application/vnd.ms-powerpoint",
|
||||
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
}
|
||||
|
||||
|
||||
@ -239,4 +242,4 @@ def hash_code(code: str, salt: bytes) -> str:
|
||||
|
||||
def captcha_key(email: str) -> str:
|
||||
return f"captcha:{email}"
|
||||
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ from strenum import StrEnum
|
||||
SERVICE_CONF = "service_conf.yaml"
|
||||
RAG_FLOW_SERVICE_NAME = "ragflow"
|
||||
|
||||
|
||||
class CustomEnum(Enum):
|
||||
@classmethod
|
||||
def valid(cls, value):
|
||||
@ -68,13 +69,13 @@ class ActiveEnum(Enum):
|
||||
|
||||
|
||||
class LLMType(StrEnum):
|
||||
CHAT = 'chat'
|
||||
EMBEDDING = 'embedding'
|
||||
SPEECH2TEXT = 'speech2text'
|
||||
IMAGE2TEXT = 'image2text'
|
||||
RERANK = 'rerank'
|
||||
TTS = 'tts'
|
||||
OCR = 'ocr'
|
||||
CHAT = "chat"
|
||||
EMBEDDING = "embedding"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
IMAGE2TEXT = "image2text"
|
||||
RERANK = "rerank"
|
||||
TTS = "tts"
|
||||
OCR = "ocr"
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
@ -86,8 +87,7 @@ class TaskStatus(StrEnum):
|
||||
SCHEDULE = "5"
|
||||
|
||||
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL,
|
||||
TaskStatus.SCHEDULE}
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE}
|
||||
|
||||
|
||||
class ParserType(StrEnum):
|
||||
@ -136,6 +136,7 @@ class FileSource(StrEnum):
|
||||
BITBUCKET = "bitbucket"
|
||||
ZENDESK = "zendesk"
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
PARSE = "Parse"
|
||||
DOWNLOAD = "Download"
|
||||
@ -145,15 +146,17 @@ class PipelineTaskType(StrEnum):
|
||||
MEMORY = "Memory"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
|
||||
PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
|
||||
|
||||
class MCPServerType(StrEnum):
|
||||
SSE = "sse"
|
||||
STREAMABLE_HTTP = "streamable-http"
|
||||
|
||||
|
||||
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||
|
||||
|
||||
class Storage(Enum):
|
||||
MINIO = 1
|
||||
AZURE_SPN = 2
|
||||
@ -165,10 +168,10 @@ class Storage(Enum):
|
||||
|
||||
|
||||
class MemoryType(Enum):
|
||||
RAW = 0b0001 # 1 << 0 = 1 (0b00000001)
|
||||
SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010)
|
||||
EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100)
|
||||
PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000)
|
||||
RAW = 0b0001 # 1 << 0 = 1 (0b00000001)
|
||||
SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010)
|
||||
EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100)
|
||||
PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000)
|
||||
|
||||
|
||||
class MemoryStorageType(StrEnum):
|
||||
@ -239,3 +242,10 @@ MINERU_DEFAULT_CONFIG = {
|
||||
"MINERU_SERVER_URL": "",
|
||||
"MINERU_DELETE_OUTPUT": 1,
|
||||
}
|
||||
|
||||
PADDLEOCR_ENV_KEYS = ["PADDLEOCR_API_URL", "PADDLEOCR_ACCESS_TOKEN", "PADDLEOCR_ALGORITHM"]
|
||||
PADDLEOCR_DEFAULT_CONFIG = {
|
||||
"PADDLEOCR_API_URL": "",
|
||||
"PADDLEOCR_ACCESS_TOKEN": None,
|
||||
"PADDLEOCR_ALGORITHM": "PaddleOCR-VL",
|
||||
}
|
||||
|
||||
@ -75,7 +75,6 @@ class AirtableConnector(LoadConnector, PollConnector):
|
||||
batch: list[Document] = []
|
||||
|
||||
for record in records:
|
||||
print(record)
|
||||
record_id = record.get("id")
|
||||
fields = record.get("fields", {})
|
||||
created_time = record.get("createdTime")
|
||||
|
||||
@ -12,6 +12,7 @@ from email.utils import collapse_rfc2231_value, parseaddr
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
import uuid
|
||||
|
||||
import bs4
|
||||
from pydantic import BaseModel
|
||||
@ -635,7 +636,6 @@ def _parse_singular_addr(raw_header: str) -> tuple[str, str]:
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
from common.data_source.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@ -164,7 +164,7 @@ class DocStoreConnection(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
|
||||
"""
|
||||
Create an index with given name
|
||||
"""
|
||||
|
||||
@ -123,7 +123,8 @@ class ESConnectionBase(DocStoreConnection):
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
|
||||
# parser_id is used by Infinity but not needed for ES (kept for interface compatibility)
|
||||
if self.index_exist(index_name, dataset_id):
|
||||
return True
|
||||
try:
|
||||
|
||||
@ -228,15 +228,26 @@ class InfinityConnectionBase(DocStoreConnection):
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
self.logger.debug(f"CREATE_IDX: Creating table {table_name}, parser_id: {parser_id}")
|
||||
|
||||
inf_conn = self.connPool.get_conn()
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
|
||||
# Use configured schema
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
|
||||
if parser_id is not None:
|
||||
from common.constants import ParserType
|
||||
if parser_id == ParserType.TABLE.value:
|
||||
# Table parser: add chunk_data JSON column to store table-specific fields
|
||||
schema["chunk_data"] = {"type": "json", "default": "{}"}
|
||||
self.logger.info("Added chunk_data column for TABLE parser")
|
||||
|
||||
vector_name = f"q_{vector_size}_vec"
|
||||
schema[vector_name] = {"type": f"vector,{vector_size},float"}
|
||||
inf_table = inf_db.create_table(
|
||||
@ -367,7 +378,10 @@ class InfinityConnectionBase(DocStoreConnection):
|
||||
num_rows = len(res)
|
||||
column_id = res["id"]
|
||||
if field_name not in res:
|
||||
return {}
|
||||
if field_name == "content_with_weight" and "content" in res:
|
||||
field_name = "content"
|
||||
else:
|
||||
return {}
|
||||
for i in range(num_rows):
|
||||
id = column_id[i]
|
||||
txt = res[field_name][i]
|
||||
@ -450,4 +464,198 @@ class InfinityConnectionBase(DocStoreConnection):
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
"""
|
||||
Execute SQL query on Infinity database via psql command.
|
||||
Transform text-to-sql for Infinity's SQL syntax.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
self.logger.debug(f"InfinityConnection.sql get sql: {sql}")
|
||||
|
||||
# Clean up SQL
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
sql = sql.replace("%", "")
|
||||
|
||||
# Transform SELECT field aliases to actual stored field names
|
||||
# Build field mapping from infinity_mapping.json comment field
|
||||
field_mapping = {}
|
||||
# Also build reverse mapping for column names in result
|
||||
reverse_mapping = {}
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
|
||||
if os.path.exists(fp_mapping):
|
||||
schema = json.load(open(fp_mapping))
|
||||
for field_name, field_info in schema.items():
|
||||
if "comment" in field_info:
|
||||
# Parse comma-separated aliases from comment
|
||||
# e.g., "docnm_kwd, title_tks, title_sm_tks"
|
||||
aliases = [a.strip() for a in field_info["comment"].split(",")]
|
||||
for alias in aliases:
|
||||
field_mapping[alias] = field_name
|
||||
reverse_mapping[field_name] = alias # Store first alias for reverse mapping
|
||||
|
||||
# Replace field names in SELECT clause
|
||||
select_match = re.search(r"(select\s+.*?)(from\s+)", sql, re.IGNORECASE)
|
||||
if select_match:
|
||||
select_clause = select_match.group(1)
|
||||
from_clause = select_match.group(2)
|
||||
|
||||
# Apply field transformations
|
||||
for alias, actual in field_mapping.items():
|
||||
select_clause = re.sub(
|
||||
rf'(^|[, ]){alias}([, ]|$)',
|
||||
rf'\1{actual}\2',
|
||||
select_clause
|
||||
)
|
||||
|
||||
sql = select_clause + from_clause + sql[select_match.end():]
|
||||
|
||||
# Also replace field names in WHERE, ORDER BY, GROUP BY, and HAVING clauses
|
||||
for alias, actual in field_mapping.items():
|
||||
# Transform in WHERE clause
|
||||
sql = re.sub(
|
||||
rf'(\bwhere\s+[^;]*?)(\b){re.escape(alias)}\b',
|
||||
rf'\1{actual}',
|
||||
sql,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
# Transform in ORDER BY clause
|
||||
sql = re.sub(
|
||||
rf'(\border by\s+[^;]*?)(\b){re.escape(alias)}\b',
|
||||
rf'\1{actual}',
|
||||
sql,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
# Transform in GROUP BY clause
|
||||
sql = re.sub(
|
||||
rf'(\bgroup by\s+[^;]*?)(\b){re.escape(alias)}\b',
|
||||
rf'\1{actual}',
|
||||
sql,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
# Transform in HAVING clause
|
||||
sql = re.sub(
|
||||
rf'(\bhaving\s+[^;]*?)(\b){re.escape(alias)}\b',
|
||||
rf'\1{actual}',
|
||||
sql,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
self.logger.debug(f"InfinityConnection.sql to execute: {sql}")
|
||||
|
||||
# Get connection parameters from the Infinity connection pool wrapper
|
||||
# We need to use INFINITY_CONN singleton, not the raw ConnectionPool
|
||||
from common.doc_store.infinity_conn_pool import INFINITY_CONN
|
||||
conn_info = INFINITY_CONN.get_conn_uri()
|
||||
|
||||
# Parse host and port from conn_info
|
||||
if conn_info and "host=" in conn_info:
|
||||
host_match = re.search(r"host=(\S+)", conn_info)
|
||||
if host_match:
|
||||
host = host_match.group(1)
|
||||
else:
|
||||
host = "infinity"
|
||||
else:
|
||||
host = "infinity"
|
||||
|
||||
# Parse port from conn_info, default to 5432 if not found
|
||||
if conn_info and "port=" in conn_info:
|
||||
port_match = re.search(r"port=(\d+)", conn_info)
|
||||
if port_match:
|
||||
port = port_match.group(1)
|
||||
else:
|
||||
port = "5432"
|
||||
else:
|
||||
port = "5432"
|
||||
|
||||
# Use psql command to execute SQL
|
||||
# Use full path to psql to avoid PATH issues
|
||||
psql_path = "/usr/bin/psql"
|
||||
# Check if psql exists at expected location, otherwise try to find it
|
||||
import shutil
|
||||
psql_from_path = shutil.which("psql")
|
||||
if psql_from_path:
|
||||
psql_path = psql_from_path
|
||||
|
||||
# Execute SQL with psql to get both column names and data in one call
|
||||
psql_cmd = [
|
||||
psql_path,
|
||||
"-h", host,
|
||||
"-p", port,
|
||||
"-c", sql,
|
||||
]
|
||||
|
||||
self.logger.debug(f"Executing psql command: {' '.join(psql_cmd)}")
|
||||
|
||||
result = subprocess.run(
|
||||
psql_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10 # 10 second timeout
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_msg = result.stderr.strip()
|
||||
raise Exception(f"psql command failed: {error_msg}\nSQL: {sql}")
|
||||
|
||||
# Parse the output
|
||||
output = result.stdout.strip()
|
||||
if not output:
|
||||
# No results
|
||||
return {
|
||||
"columns": [],
|
||||
"rows": []
|
||||
} if format == "json" else []
|
||||
|
||||
# Parse psql table output which has format:
|
||||
# col1 | col2 | col3
|
||||
# -----+-----+-----
|
||||
# val1 | val2 | val3
|
||||
lines = output.split("\n")
|
||||
|
||||
# Extract column names from first line
|
||||
columns = []
|
||||
rows = []
|
||||
|
||||
if len(lines) >= 1:
|
||||
header_line = lines[0]
|
||||
for col_name in header_line.split("|"):
|
||||
col_name = col_name.strip()
|
||||
if col_name:
|
||||
columns.append({"name": col_name})
|
||||
|
||||
# Data starts after the separator line (line with dashes)
|
||||
data_start = 2 if len(lines) >= 2 and "-" in lines[1] else 1
|
||||
for i in range(data_start, len(lines)):
|
||||
line = lines[i].strip()
|
||||
# Skip empty lines and footer lines like "(1 row)"
|
||||
if not line or re.match(r"^\(\d+ row", line):
|
||||
continue
|
||||
# Split by | and strip each cell
|
||||
row = [cell.strip() for cell in line.split("|")]
|
||||
# Ensure row matches column count
|
||||
if len(row) == len(columns):
|
||||
rows.append(row)
|
||||
elif len(row) > len(columns):
|
||||
# Row has more cells than columns - truncate
|
||||
rows.append(row[:len(columns)])
|
||||
elif len(row) < len(columns):
|
||||
# Row has fewer cells - pad with empty strings
|
||||
rows.append(row + [""] * (len(columns) - len(row)))
|
||||
|
||||
if format == "json":
|
||||
result = {
|
||||
"columns": columns,
|
||||
"rows": rows[:fetch_size] if fetch_size > 0 else rows
|
||||
}
|
||||
else:
|
||||
result = rows[:fetch_size] if fetch_size > 0 else rows
|
||||
|
||||
return result
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
self.logger.exception(f"InfinityConnection.sql timeout. SQL:\n{sql}")
|
||||
raise Exception(f"SQL timeout\n\nSQL: {sql}")
|
||||
except Exception as e:
|
||||
self.logger.exception(f"InfinityConnection.sql got exception. SQL:\n{sql}")
|
||||
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
|
||||
|
||||
@ -31,7 +31,11 @@ class InfinityConnectionPool:
|
||||
if hasattr(settings, "INFINITY"):
|
||||
self.INFINITY_CONFIG = settings.INFINITY
|
||||
else:
|
||||
self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
self.INFINITY_CONFIG = settings.get_base_config("infinity", {
|
||||
"uri": "infinity:23817",
|
||||
"postgres_port": 5432,
|
||||
"db_name": "default_db"
|
||||
})
|
||||
|
||||
infinity_uri = self.INFINITY_CONFIG["uri"]
|
||||
if ":" in infinity_uri:
|
||||
@ -61,6 +65,19 @@ class InfinityConnectionPool:
|
||||
def get_conn_pool(self):
|
||||
return self.conn_pool
|
||||
|
||||
def get_conn_uri(self):
|
||||
"""
|
||||
Get connection URI for PostgreSQL protocol.
|
||||
"""
|
||||
infinity_uri = self.INFINITY_CONFIG["uri"]
|
||||
postgres_port = self.INFINITY_CONFIG["postgres_port"]
|
||||
db_name = self.INFINITY_CONFIG["db_name"]
|
||||
|
||||
if ":" in infinity_uri:
|
||||
host, _ = infinity_uri.split(":")
|
||||
return f"host={host} port={postgres_port} dbname={db_name}"
|
||||
return f"host=localhost port={postgres_port} dbname={db_name}"
|
||||
|
||||
def refresh_conn_pool(self):
|
||||
try:
|
||||
inf_conn = self.conn_pool.get_conn()
|
||||
|
||||
@ -212,7 +212,7 @@ def update_metadata_to(metadata, meta):
|
||||
return metadata
|
||||
|
||||
|
||||
def metadata_schema(metadata: list|None) -> Dict[str, Any]:
|
||||
def metadata_schema(metadata: dict|list|None) -> Dict[str, Any]:
|
||||
if not metadata:
|
||||
return {}
|
||||
properties = {}
|
||||
@ -238,3 +238,47 @@ def metadata_schema(metadata: list|None) -> Dict[str, Any]:
|
||||
|
||||
json_schema["additionalProperties"] = False
|
||||
return json_schema
|
||||
|
||||
|
||||
def _is_json_schema(obj: dict) -> bool:
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
if "$schema" in obj:
|
||||
return True
|
||||
return obj.get("type") == "object" and isinstance(obj.get("properties"), dict)
|
||||
|
||||
|
||||
def _is_metadata_list(obj: list) -> bool:
|
||||
if not isinstance(obj, list) or not obj:
|
||||
return False
|
||||
for item in obj:
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
key = item.get("key")
|
||||
if not isinstance(key, str) or not key:
|
||||
return False
|
||||
if "enum" in item and not isinstance(item["enum"], list):
|
||||
return False
|
||||
if "description" in item and not isinstance(item["description"], str):
|
||||
return False
|
||||
if "descriptions" in item and not isinstance(item["descriptions"], str):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def turn2jsonschema(obj: dict | list) -> Dict[str, Any]:
|
||||
if isinstance(obj, dict) and _is_json_schema(obj):
|
||||
return obj
|
||||
if isinstance(obj, list) and _is_metadata_list(obj):
|
||||
normalized = []
|
||||
for item in obj:
|
||||
description = item.get("description", item.get("descriptions", ""))
|
||||
normalized_item = {
|
||||
"key": item.get("key"),
|
||||
"description": description,
|
||||
}
|
||||
if "enum" in item:
|
||||
normalized_item["enum"] = item["enum"]
|
||||
normalized.append(normalized_item)
|
||||
return metadata_schema(normalized)
|
||||
return {}
|
||||
|
||||
@ -14,15 +14,20 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import functools
|
||||
import hashlib
|
||||
import uuid
|
||||
import requests
|
||||
import threading
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
@ -106,3 +111,22 @@ def pip_install_torch():
|
||||
logging.info("Installing pytorch")
|
||||
pkg_names = ["torch>=2.5.0,<3.0.0"]
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
|
||||
|
||||
|
||||
def _thread_pool_executor():
|
||||
max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
|
||||
try:
|
||||
max_workers = int(max_workers_env)
|
||||
except ValueError:
|
||||
max_workers = 128
|
||||
if max_workers < 1:
|
||||
max_workers = 1
|
||||
return ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
|
||||
async def thread_pool_exec(func, *args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
if kwargs:
|
||||
func = functools.partial(func, *args, **kwargs)
|
||||
return await loop.run_in_executor(_thread_pool_executor(), func)
|
||||
return await loop.run_in_executor(_thread_pool_executor(), func, *args)
|
||||
|
||||
@ -26,5 +26,8 @@ def normalize_layout_recognizer(layout_recognizer_raw: Any) -> tuple[Any, str |
|
||||
if lowered.endswith("@mineru"):
|
||||
parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0]
|
||||
layout_recognizer = "MinerU"
|
||||
elif lowered.endswith("@paddleocr"):
|
||||
parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0]
|
||||
layout_recognizer = "PaddleOCR"
|
||||
|
||||
return layout_recognizer, parser_model_name
|
||||
|
||||
@ -249,7 +249,11 @@ def init_settings():
|
||||
ES = get_base_config("es", {})
|
||||
docStoreConn = rag.utils.es_conn.ESConnection()
|
||||
elif lower_case_doc_engine == "infinity":
|
||||
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
INFINITY = get_base_config("infinity", {
|
||||
"uri": "infinity:23817",
|
||||
"postgres_port": 5432,
|
||||
"db_name": "default_db"
|
||||
})
|
||||
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
||||
elif lower_case_doc_engine == "opensearch":
|
||||
OS = get_base_config("os", {})
|
||||
@ -257,6 +261,9 @@ def init_settings():
|
||||
elif lower_case_doc_engine == "oceanbase":
|
||||
OB = get_base_config("oceanbase", {})
|
||||
docStoreConn = rag.utils.ob_conn.OBConnection()
|
||||
elif lower_case_doc_engine == "seekdb":
|
||||
OB = get_base_config("seekdb", {})
|
||||
docStoreConn = rag.utils.ob_conn.OBConnection()
|
||||
else:
|
||||
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
||||
|
||||
@ -266,7 +273,11 @@ def init_settings():
|
||||
ES = get_base_config("es", {})
|
||||
msgStoreConn = memory_es_conn.ESConnection()
|
||||
elif DOC_ENGINE == "infinity":
|
||||
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
INFINITY = get_base_config("infinity", {
|
||||
"uri": "infinity:23817",
|
||||
"postgres_port": 5432,
|
||||
"db_name": "default_db"
|
||||
})
|
||||
msgStoreConn = memory_infinity_conn.InfinityConnection()
|
||||
|
||||
global AZURE, S3, MINIO, OSS, GCS
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
"docnm": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "docnm_kwd, title_tks, title_sm_tks"},
|
||||
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"important_kwd_empty_count": {"type": "integer", "default": 0},
|
||||
"important_keywords": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "important_kwd, important_tks"},
|
||||
"questions": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "question_kwd, question_tks"},
|
||||
"content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_with_weight, content_ltks, content_sm_ltks"},
|
||||
|
||||
@ -5531,6 +5531,51 @@
|
||||
"status": "1",
|
||||
"rank": "900",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
"name": "PaddleOCR",
|
||||
"logo": "",
|
||||
"tags": "OCR",
|
||||
"status": "1",
|
||||
"rank": "910",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
"name": "n1n",
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "900",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gpt-4o-mini",
|
||||
"tags": "LLM,CHAT,128K,IMAGE2TEXT",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-4o",
|
||||
"tags": "LLM,CHAT,128K,IMAGE2TEXT",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-3.5-turbo",
|
||||
"tags": "LLM,CHAT,4K",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek-chat",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@ -29,6 +29,7 @@ os:
|
||||
password: 'infini_rag_flow_OS_01'
|
||||
infinity:
|
||||
uri: 'localhost:23817'
|
||||
postgres_port: 5432
|
||||
db_name: 'default_db'
|
||||
oceanbase:
|
||||
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config
|
||||
@ -67,9 +68,11 @@ user_default_llm:
|
||||
# oss:
|
||||
# access_key: 'access_key'
|
||||
# secret_key: 'secret_key'
|
||||
# endpoint_url: 'http://oss-cn-hangzhou.aliyuncs.com'
|
||||
# endpoint_url: 'https://s3.oss-cn-hangzhou.aliyuncs.com'
|
||||
# region: 'cn-hangzhou'
|
||||
# bucket: 'bucket_name'
|
||||
# signature_version: 's3'
|
||||
# addressing_style: 'virtual'
|
||||
# azure:
|
||||
# auth_type: 'sas'
|
||||
# container_url: 'container_url'
|
||||
|
||||
64
conf/system_settings.json
Normal file
64
conf/system_settings.json
Normal file
@ -0,0 +1,64 @@
|
||||
{
|
||||
"system_settings": [
|
||||
{
|
||||
"name": "enable_whitelist",
|
||||
"source": "variable",
|
||||
"data_type": "bool",
|
||||
"value": "true"
|
||||
},
|
||||
{
|
||||
"name": "default_role",
|
||||
"source": "variable",
|
||||
"data_type": "string",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"name": "mail.server",
|
||||
"source": "variable",
|
||||
"data_type": "string",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"name": "mail.port",
|
||||
"source": "variable",
|
||||
"data_type": "integer",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"name": "mail.use_ssl",
|
||||
"source": "variable",
|
||||
"data_type": "bool",
|
||||
"value": "false"
|
||||
},
|
||||
{
|
||||
"name": "mail.use_tls",
|
||||
"source": "variable",
|
||||
"data_type": "bool",
|
||||
"value": "false"
|
||||
},
|
||||
{
|
||||
"name": "mail.username",
|
||||
"source": "variable",
|
||||
"data_type": "string",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"name": "mail.password",
|
||||
"source": "variable",
|
||||
"data_type": "string",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"name": "mail.timeout",
|
||||
"source": "variable",
|
||||
"data_type": "integer",
|
||||
"value": "10"
|
||||
},
|
||||
{
|
||||
"name": "mail.default_sender",
|
||||
"source": "variable",
|
||||
"data_type": "string",
|
||||
"value": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -103,6 +103,31 @@ We use vision information to resolve problems as human being.
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
|
||||
</div>
|
||||
|
||||
- **Table Auto-Rotation**. For scanned PDFs where tables may be incorrectly oriented (rotated 90°, 180°, or 270°),
|
||||
the PDF parser automatically detects the best rotation angle using OCR confidence scores before performing
|
||||
table structure recognition. This significantly improves OCR accuracy and table structure detection for rotated tables.
|
||||
|
||||
The feature evaluates 4 rotation angles (0°, 90°, 180°, 270°) and selects the one with highest OCR confidence.
|
||||
After determining the best orientation, it re-performs OCR on the correctly rotated table image.
|
||||
|
||||
This feature is **enabled by default**. You can control it via environment variable:
|
||||
```bash
|
||||
# Disable table auto-rotation
|
||||
export TABLE_AUTO_ROTATE=false
|
||||
|
||||
# Enable table auto-rotation (default)
|
||||
export TABLE_AUTO_ROTATE=true
|
||||
```
|
||||
|
||||
Or via API parameter:
|
||||
```python
|
||||
from deepdoc.parser import PdfParser
|
||||
|
||||
parser = PdfParser()
|
||||
# Disable auto-rotation for this call
|
||||
boxes, tables = parser(pdf_path, auto_rotate_tables=False)
|
||||
```
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Parser
|
||||
|
||||
@ -102,6 +102,30 @@ export HF_ENDPOINT=https://hf-mirror.com
|
||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
|
||||
</div>
|
||||
|
||||
- **表格自动旋转(Table Auto-Rotation)**。对于扫描的 PDF 文档,表格可能存在方向错误(旋转了 90°、180° 或 270°),
|
||||
PDF 解析器会在进行表格结构识别之前,自动使用 OCR 置信度来检测最佳旋转角度。这大大提高了旋转表格的 OCR 准确性和表格结构检测效果。
|
||||
|
||||
该功能会评估 4 个旋转角度(0°、90°、180°、270°),并选择 OCR 置信度最高的角度。
|
||||
确定最佳方向后,会对旋转后的表格图像重新进行 OCR 识别。
|
||||
|
||||
此功能**默认启用**。您可以通过环境变量控制:
|
||||
```bash
|
||||
# 禁用表格自动旋转
|
||||
export TABLE_AUTO_ROTATE=false
|
||||
|
||||
# 启用表格自动旋转(默认)
|
||||
export TABLE_AUTO_ROTATE=true
|
||||
```
|
||||
|
||||
或通过 API 参数控制:
|
||||
```python
|
||||
from deepdoc.parser import PdfParser
|
||||
|
||||
parser = PdfParser()
|
||||
# 禁用此次调用的自动旋转
|
||||
boxes, tables = parser(pdf_path, auto_rotate_tables=False)
|
||||
```
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 解析器
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@ -21,9 +22,10 @@ from common.constants import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.connection_utils import timeout
|
||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||
from rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||
|
||||
from rag.prompts.generator import vision_llm_figure_describe_prompt, vision_llm_figure_describe_prompt_with_context
|
||||
from rag.nlp import append_context2table_image4pdf
|
||||
|
||||
# need to delete before pr
|
||||
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||
if not figures_data_without_positions:
|
||||
return []
|
||||
@ -36,7 +38,6 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||
if isinstance(figure_data[1], Image.Image)
|
||||
]
|
||||
|
||||
|
||||
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
|
||||
if not sections:
|
||||
return tbls
|
||||
@ -84,20 +85,36 @@ def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs):
|
||||
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
|
||||
if not tbls:
|
||||
return []
|
||||
sections = kwargs.get("sections")
|
||||
parser_config = kwargs.get("parser_config", {})
|
||||
context_size = max(0, int(parser_config.get("image_context_size", 0) or 0))
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||
except Exception:
|
||||
vision_model = None
|
||||
if vision_model:
|
||||
|
||||
def is_figure_item(item):
|
||||
return (
|
||||
isinstance(item[0][0], Image.Image) and
|
||||
isinstance(item[0][1], list)
|
||||
)
|
||||
return isinstance(item[0][0], Image.Image) and isinstance(item[0][1], list)
|
||||
|
||||
figures_data = [item for item in tbls if is_figure_item(item)]
|
||||
figure_contexts = []
|
||||
if sections and figures_data and context_size > 0:
|
||||
figure_contexts = append_context2table_image4pdf(
|
||||
sections,
|
||||
figures_data,
|
||||
context_size,
|
||||
return_context=True,
|
||||
)
|
||||
try:
|
||||
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
|
||||
docx_vision_parser = VisionFigureParser(
|
||||
vision_model=vision_model,
|
||||
figures_data=figures_data,
|
||||
figure_contexts=figure_contexts,
|
||||
context_size=context_size,
|
||||
**kwargs,
|
||||
)
|
||||
boosted_figures = docx_vision_parser(callback=callback)
|
||||
tbls = [item for item in tbls if not is_figure_item(item)]
|
||||
tbls.extend(boosted_figures)
|
||||
@ -106,12 +123,57 @@ def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
|
||||
return tbls
|
||||
|
||||
|
||||
shared_executor = ThreadPoolExecutor(max_workers=10)
|
||||
def vision_figure_parser_docx_wrapper_naive(chunks, idx_lst, callback=None, **kwargs):
|
||||
if not chunks:
|
||||
return []
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||
except Exception:
|
||||
vision_model = None
|
||||
if vision_model:
|
||||
@timeout(30, 3)
|
||||
def worker(idx, ck):
|
||||
context_above = ck.get("context_above", "")
|
||||
context_below = ck.get("context_below", "")
|
||||
if context_above or context_below:
|
||||
prompt = vision_llm_figure_describe_prompt_with_context(
|
||||
# context_above + caption if any
|
||||
context_above=ck.get("context_above") + ck.get("text", ""),
|
||||
context_below=ck.get("context_below"),
|
||||
)
|
||||
logging.info(f"[VisionFigureParser] figure={idx} context_above_len={len(context_above)} context_below_len={len(context_below)} prompt=with_context")
|
||||
logging.info(f"[VisionFigureParser] figure={idx} context_above_snippet={context_above[:512]}")
|
||||
logging.info(f"[VisionFigureParser] figure={idx} context_below_snippet={context_below[:512]}")
|
||||
else:
|
||||
prompt = vision_llm_figure_describe_prompt()
|
||||
logging.info(f"[VisionFigureParser] figure={idx} context_len=0 prompt=default")
|
||||
|
||||
description_text = picture_vision_llm_chunk(
|
||||
binary=ck.get("image"),
|
||||
vision_model=vision_model,
|
||||
prompt=prompt,
|
||||
callback=callback,
|
||||
)
|
||||
return idx, description_text
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [
|
||||
executor.submit(worker, idx, chunks[idx])
|
||||
for idx in idx_lst
|
||||
]
|
||||
|
||||
for future in as_completed(futures):
|
||||
idx, description = future.result()
|
||||
chunks[idx]['text'] += description
|
||||
|
||||
shared_executor = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
class VisionFigureParser:
|
||||
def __init__(self, vision_model, figures_data, *args, **kwargs):
|
||||
self.vision_model = vision_model
|
||||
self.figure_contexts = kwargs.get("figure_contexts") or []
|
||||
self.context_size = max(0, int(kwargs.get("context_size", 0) or 0))
|
||||
self._extract_figures_info(figures_data)
|
||||
assert len(self.figures) == len(self.descriptions)
|
||||
assert not self.positions or (len(self.figures) == len(self.positions))
|
||||
@ -156,10 +218,25 @@ class VisionFigureParser:
|
||||
|
||||
@timeout(30, 3)
|
||||
def process(figure_idx, figure_binary):
|
||||
context_above = ""
|
||||
context_below = ""
|
||||
if figure_idx < len(self.figure_contexts):
|
||||
context_above, context_below = self.figure_contexts[figure_idx]
|
||||
if context_above or context_below:
|
||||
prompt = vision_llm_figure_describe_prompt_with_context(
|
||||
context_above=context_above,
|
||||
context_below=context_below,
|
||||
)
|
||||
logging.info(f"[VisionFigureParser] figure={figure_idx} context_size={self.context_size} context_above_len={len(context_above)} context_below_len={len(context_below)} prompt=with_context")
|
||||
logging.info(f"[VisionFigureParser] figure={figure_idx} context_above_snippet={context_above[:512]}")
|
||||
logging.info(f"[VisionFigureParser] figure={figure_idx} context_below_snippet={context_below[:512]}")
|
||||
else:
|
||||
prompt = vision_llm_figure_describe_prompt()
|
||||
logging.info(f"[VisionFigureParser] figure={figure_idx} context_size={self.context_size} context_len=0 prompt=default")
|
||||
description_text = picture_vision_llm_chunk(
|
||||
binary=figure_binary,
|
||||
vision_model=self.vision_model,
|
||||
prompt=vision_llm_figure_describe_prompt(),
|
||||
prompt=prompt,
|
||||
callback=callback,
|
||||
)
|
||||
return figure_idx, description_text
|
||||
|
||||
@ -17,6 +17,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
@ -138,39 +139,58 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
self.outlines = []
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
@staticmethod
|
||||
def _is_zipinfo_symlink(member: zipfile.ZipInfo) -> bool:
|
||||
return (member.external_attr >> 16) & 0o170000 == 0o120000
|
||||
|
||||
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
|
||||
self.logger.info(f"[MinerU] Extract zip: zip_path={zip_path}, extract_to={extract_to}, root_hint={root_dir}")
|
||||
base_dir = Path(extract_to).resolve()
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
members = zip_ref.infolist()
|
||||
if not root_dir:
|
||||
files = zip_ref.namelist()
|
||||
if files and files[0].endswith("/"):
|
||||
root_dir = files[0]
|
||||
if members and members[0].filename.endswith("/"):
|
||||
root_dir = members[0].filename
|
||||
else:
|
||||
root_dir = None
|
||||
if root_dir:
|
||||
root_dir = root_dir.replace("\\", "/")
|
||||
if not root_dir.endswith("/"):
|
||||
root_dir += "/"
|
||||
|
||||
if not root_dir or not root_dir.endswith("/"):
|
||||
self.logger.info(f"[MinerU] No root directory found, extracting all (root_hint={root_dir})")
|
||||
zip_ref.extractall(extract_to)
|
||||
return
|
||||
for member in members:
|
||||
if member.flag_bits & 0x1:
|
||||
raise RuntimeError(f"[MinerU] Encrypted zip entry not supported: {member.filename}")
|
||||
if self._is_zipinfo_symlink(member):
|
||||
raise RuntimeError(f"[MinerU] Symlink zip entry not supported: {member.filename}")
|
||||
|
||||
root_len = len(root_dir)
|
||||
for member in zip_ref.infolist():
|
||||
filename = member.filename
|
||||
if filename == root_dir:
|
||||
name = member.filename.replace("\\", "/")
|
||||
if root_dir and name == root_dir:
|
||||
self.logger.info("[MinerU] Ignore root folder...")
|
||||
continue
|
||||
if root_dir and name.startswith(root_dir):
|
||||
name = name[len(root_dir) :]
|
||||
if not name:
|
||||
continue
|
||||
if name.startswith("/") or name.startswith("//") or re.match(r"^[A-Za-z]:", name):
|
||||
raise RuntimeError(f"[MinerU] Unsafe zip path (absolute): {member.filename}")
|
||||
|
||||
path = filename
|
||||
if path.startswith(root_dir):
|
||||
path = path[root_len:]
|
||||
parts = [p for p in name.split("/") if p not in ("", ".")]
|
||||
if any(p == ".." for p in parts):
|
||||
raise RuntimeError(f"[MinerU] Unsafe zip path (traversal): {member.filename}")
|
||||
|
||||
rel_path = os.path.join(*parts) if parts else ""
|
||||
dest_path = (Path(extract_to) / rel_path).resolve(strict=False)
|
||||
if dest_path != base_dir and base_dir not in dest_path.parents:
|
||||
raise RuntimeError(f"[MinerU] Unsafe zip path (escape): {member.filename}")
|
||||
|
||||
full_path = os.path.join(extract_to, path)
|
||||
if member.is_dir():
|
||||
os.makedirs(full_path, exist_ok=True)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(zip_ref.read(filename))
|
||||
os.makedirs(dest_path, exist_ok=True)
|
||||
continue
|
||||
|
||||
os.makedirs(dest_path.parent, exist_ok=True)
|
||||
with zip_ref.open(member) as src, open(dest_path, "wb") as dst:
|
||||
shutil.copyfileobj(src, dst)
|
||||
|
||||
@staticmethod
|
||||
def _is_http_endpoint_valid(url, timeout=5):
|
||||
@ -237,8 +257,6 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
output_path = tempfile.mkdtemp(prefix=f"{pdf_file_name}_{options.method}_", dir=str(output_dir))
|
||||
output_zip_path = os.path.join(str(output_dir), f"{Path(output_path).name}.zip")
|
||||
|
||||
files = {"files": (pdf_file_name + ".pdf", open(pdf_file_path, "rb"), "application/pdf")}
|
||||
|
||||
data = {
|
||||
"output_dir": "./output",
|
||||
"lang_list": options.lang,
|
||||
@ -270,26 +288,35 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
self.logger.info(f"[MinerU] invoke api: {self.mineru_api}/file_parse backend={options.backend} server_url={data.get('server_url')}")
|
||||
if callback:
|
||||
callback(0.20, f"[MinerU] invoke api: {self.mineru_api}/file_parse")
|
||||
response = requests.post(url=f"{self.mineru_api}/file_parse", files=files, data=data, headers=headers,
|
||||
timeout=1800)
|
||||
with open(pdf_file_path, "rb") as pdf_file:
|
||||
files = {"files": (pdf_file_name + ".pdf", pdf_file, "application/pdf")}
|
||||
with requests.post(
|
||||
url=f"{self.mineru_api}/file_parse",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=1800,
|
||||
stream=True,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if content_type.startswith("application/zip"):
|
||||
self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
|
||||
|
||||
response.raise_for_status()
|
||||
if response.headers.get("Content-Type") == "application/zip":
|
||||
self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
|
||||
if callback:
|
||||
callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
|
||||
|
||||
if callback:
|
||||
callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
|
||||
with open(output_zip_path, "wb") as f:
|
||||
response.raw.decode_content = True
|
||||
shutil.copyfileobj(response.raw, f)
|
||||
|
||||
with open(output_zip_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
self.logger.info(f"[MinerU] Unzip to {output_path}...")
|
||||
self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
|
||||
|
||||
self.logger.info(f"[MinerU] Unzip to {output_path}...")
|
||||
self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
|
||||
|
||||
if callback:
|
||||
callback(0.40, f"[MinerU] Unzip to {output_path}...")
|
||||
else:
|
||||
self.logger.warning(f"[MinerU] not zip returned from api: {response.headers.get('Content-Type')}")
|
||||
if callback:
|
||||
callback(0.40, f"[MinerU] Unzip to {output_path}...")
|
||||
else:
|
||||
self.logger.warning(f"[MinerU] not zip returned from api: {content_type}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"[MinerU] api failed with exception {e}")
|
||||
self.logger.info("[MinerU] Api completed successfully.")
|
||||
|
||||
554
deepdoc/parser/paddleocr_parser.py
Normal file
554
deepdoc/parser/paddleocr_parser.py
Normal file
@ -0,0 +1,554 @@
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from io import BytesIO
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, Literal, Optional, Union, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pdfplumber
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
except Exception:
|
||||
|
||||
class RAGFlowPdfParser:
|
||||
pass
|
||||
|
||||
|
||||
AlgorithmType = Literal["PaddleOCR-VL"]
|
||||
SectionTuple = tuple[str, ...]
|
||||
TableTuple = tuple[str, ...]
|
||||
ParseResult = tuple[list[SectionTuple], list[TableTuple]]
|
||||
|
||||
|
||||
_MARKDOWN_IMAGE_PATTERN = re.compile(
|
||||
r"""
|
||||
<div[^>]*>\s*
|
||||
<img[^>]*/>\s*
|
||||
</div>
|
||||
|
|
||||
<img[^>]*/>
|
||||
""",
|
||||
re.IGNORECASE | re.VERBOSE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def _remove_images_from_markdown(markdown: str) -> str:
|
||||
return _MARKDOWN_IMAGE_PATTERN.sub("", markdown)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaddleOCRVLConfig:
|
||||
"""Configuration for PaddleOCR-VL algorithm."""
|
||||
|
||||
use_doc_orientation_classify: Optional[bool] = False
|
||||
use_doc_unwarping: Optional[bool] = False
|
||||
use_layout_detection: Optional[bool] = None
|
||||
use_polygon_points: Optional[bool] = None
|
||||
use_chart_recognition: Optional[bool] = None
|
||||
use_seal_recognition: Optional[bool] = None
|
||||
use_ocr_for_image_block: Optional[bool] = None
|
||||
layout_threshold: Optional[Union[float, dict]] = None
|
||||
layout_nms: Optional[bool] = None
|
||||
layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None
|
||||
layout_merge_bboxes_mode: Optional[Union[str, dict]] = None
|
||||
prompt_label: Optional[str] = None
|
||||
format_block_content: Optional[bool] = True
|
||||
repetition_penalty: Optional[float] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
min_pixels: Optional[int] = None
|
||||
max_pixels: Optional[int] = None
|
||||
max_new_tokens: Optional[int] = None
|
||||
merge_layout_blocks: Optional[bool] = False
|
||||
markdown_ignore_labels: Optional[List[str]] = None
|
||||
vlm_extra_args: Optional[dict] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaddleOCRConfig:
|
||||
"""Main configuration for PaddleOCR parser."""
|
||||
|
||||
api_url: str = ""
|
||||
access_token: Optional[str] = None
|
||||
algorithm: AlgorithmType = "PaddleOCR-VL"
|
||||
request_timeout: int = 600
|
||||
prettify_markdown: bool = True
|
||||
show_formula_number: bool = True
|
||||
visualize: bool = False
|
||||
additional_params: dict[str, Any] = field(default_factory=dict)
|
||||
algorithm_config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config: Optional[dict[str, Any]]) -> "PaddleOCRConfig":
|
||||
"""Create configuration from dictionary."""
|
||||
if not config:
|
||||
return cls()
|
||||
|
||||
cfg = config.copy()
|
||||
algorithm = cfg.get("algorithm", "PaddleOCR-VL")
|
||||
|
||||
# Validate algorithm
|
||||
if algorithm not in ("PaddleOCR-VL",):
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
# Extract algorithm-specific configuration
|
||||
algorithm_config: dict[str, Any] = {}
|
||||
if algorithm == "PaddleOCR-VL":
|
||||
# Create default PaddleOCRVLConfig object and convert to dict
|
||||
algorithm_config = asdict(PaddleOCRVLConfig())
|
||||
algorithm_config_user = cfg.get("algorithm_config")
|
||||
if isinstance(algorithm_config_user, dict):
|
||||
algorithm_config.update({k: v for k, v in algorithm_config_user.items() if v is not None})
|
||||
|
||||
# Remove processed keys
|
||||
cfg.pop("algorithm_config", None)
|
||||
|
||||
# Prepare initialization arguments
|
||||
field_names = {field.name for field in fields(cls)}
|
||||
init_kwargs: dict[str, Any] = {}
|
||||
|
||||
for field_name in field_names:
|
||||
if field_name in cfg:
|
||||
init_kwargs[field_name] = cfg[field_name]
|
||||
|
||||
init_kwargs["algorithm_config"] = algorithm_config
|
||||
|
||||
return cls(**init_kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs: Any) -> "PaddleOCRConfig":
|
||||
"""Create configuration from keyword arguments."""
|
||||
return cls.from_dict(kwargs)
|
||||
|
||||
|
||||
class PaddleOCRParser(RAGFlowPdfParser):
|
||||
"""Parser for PDF documents using PaddleOCR API."""
|
||||
|
||||
_ZOOMIN = 2
|
||||
|
||||
_COMMON_FIELD_MAPPING: ClassVar[dict[str, str]] = {
|
||||
"prettify_markdown": "prettifyMarkdown",
|
||||
"show_formula_number": "showFormulaNumber",
|
||||
"visualize": "visualize",
|
||||
}
|
||||
|
||||
_ALGORITHM_FIELD_MAPPINGS: ClassVar[dict[str, dict[str, str]]] = {
|
||||
"PaddleOCR-VL": {
|
||||
"use_doc_orientation_classify": "useDocOrientationClassify",
|
||||
"use_doc_unwarping": "useDocUnwarping",
|
||||
"use_layout_detection": "useLayoutDetection",
|
||||
"use_polygon_points": "usePolygonPoints",
|
||||
"use_chart_recognition": "useChartRecognition",
|
||||
"use_seal_recognition": "useSealRecognition",
|
||||
"use_ocr_for_image_block": "useOcrForImageBlock",
|
||||
"layout_threshold": "layoutThreshold",
|
||||
"layout_nms": "layoutNms",
|
||||
"layout_unclip_ratio": "layoutUnclipRatio",
|
||||
"layout_merge_bboxes_mode": "layoutMergeBboxesMode",
|
||||
"prompt_label": "promptLabel",
|
||||
"format_block_content": "formatBlockContent",
|
||||
"repetition_penalty": "repetitionPenalty",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
"min_pixels": "minPixels",
|
||||
"max_pixels": "maxPixels",
|
||||
"max_new_tokens": "maxNewTokens",
|
||||
"merge_layout_blocks": "mergeLayoutBlocks",
|
||||
"markdown_ignore_labels": "markdownIgnoreLabels",
|
||||
"vlm_extra_args": "vlmExtraArgs",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_url: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
algorithm: AlgorithmType = "PaddleOCR-VL",
|
||||
*,
|
||||
request_timeout: int = 600,
|
||||
):
|
||||
"""Initialize PaddleOCR parser."""
|
||||
super().__init__()
|
||||
|
||||
self.api_url = api_url.rstrip("/") if api_url else os.getenv("PADDLEOCR_API_URL", "")
|
||||
self.access_token = access_token or os.getenv("PADDLEOCR_ACCESS_TOKEN")
|
||||
self.algorithm = algorithm
|
||||
self.request_timeout = request_timeout
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
# Force PDF file type
|
||||
self.file_type = 0
|
||||
|
||||
# Initialize page images for cropping
|
||||
self.page_images: list[Image.Image] = []
|
||||
self.page_from = 0
|
||||
|
||||
# Public methods
|
||||
def check_installation(self) -> tuple[bool, str]:
|
||||
"""Check if the parser is properly installed and configured."""
|
||||
if not self.api_url:
|
||||
return False, "[PaddleOCR] API URL not configured"
|
||||
|
||||
# TODO [@Bobholamovic]: Check URL availability and token validity
|
||||
|
||||
return True, ""
|
||||
|
||||
def parse_pdf(
|
||||
self,
|
||||
filepath: str | PathLike[str],
|
||||
binary: BytesIO | bytes | None = None,
|
||||
callback: Optional[Callable[[float, str], None]] = None,
|
||||
*,
|
||||
parse_method: str = "raw",
|
||||
api_url: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
algorithm: Optional[AlgorithmType] = None,
|
||||
request_timeout: Optional[int] = None,
|
||||
prettify_markdown: Optional[bool] = None,
|
||||
show_formula_number: Optional[bool] = None,
|
||||
visualize: Optional[bool] = None,
|
||||
additional_params: Optional[dict[str, Any]] = None,
|
||||
algorithm_config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> ParseResult:
|
||||
"""Parse PDF document using PaddleOCR API."""
|
||||
# Create configuration - pass all kwargs to capture VL config parameters
|
||||
config_dict = {
|
||||
"api_url": api_url if api_url is not None else self.api_url,
|
||||
"access_token": access_token if access_token is not None else self.access_token,
|
||||
"algorithm": algorithm if algorithm is not None else self.algorithm,
|
||||
"request_timeout": request_timeout if request_timeout is not None else self.request_timeout,
|
||||
}
|
||||
if prettify_markdown is not None:
|
||||
config_dict["prettify_markdown"] = prettify_markdown
|
||||
if show_formula_number is not None:
|
||||
config_dict["show_formula_number"] = show_formula_number
|
||||
if visualize is not None:
|
||||
config_dict["visualize"] = visualize
|
||||
if additional_params is not None:
|
||||
config_dict["additional_params"] = additional_params
|
||||
if algorithm_config is not None:
|
||||
config_dict["algorithm_config"] = algorithm_config
|
||||
|
||||
cfg = PaddleOCRConfig.from_dict(config_dict)
|
||||
|
||||
if not cfg.api_url:
|
||||
raise RuntimeError("[PaddleOCR] API URL missing")
|
||||
|
||||
# Prepare file data and generate page images for cropping
|
||||
data_bytes = self._prepare_file_data(filepath, binary)
|
||||
|
||||
# Generate page images for cropping functionality
|
||||
input_source = filepath if binary is None else binary
|
||||
try:
|
||||
self.__images__(input_source, callback=callback)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"[PaddleOCR] Failed to generate page images for cropping: {e}")
|
||||
|
||||
# Build and send request
|
||||
result = self._send_request(data_bytes, cfg, callback)
|
||||
|
||||
# Process response
|
||||
sections = self._transfer_to_sections(result, algorithm=cfg.algorithm, parse_method=parse_method)
|
||||
if callback:
|
||||
callback(0.9, f"[PaddleOCR] done, sections: {len(sections)}")
|
||||
|
||||
tables = self._transfer_to_tables(result)
|
||||
if callback:
|
||||
callback(1.0, f"[PaddleOCR] done, tables: {len(tables)}")
|
||||
|
||||
return sections, tables
|
||||
|
||||
def _prepare_file_data(self, filepath: str | PathLike[str], binary: BytesIO | bytes | None) -> bytes:
|
||||
"""Prepare file data for API request."""
|
||||
source_path = Path(filepath)
|
||||
|
||||
if binary is not None:
|
||||
if isinstance(binary, (bytes, bytearray)):
|
||||
return binary
|
||||
return binary.getbuffer().tobytes()
|
||||
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundError(f"[PaddleOCR] file not found: {source_path}")
|
||||
|
||||
return source_path.read_bytes()
|
||||
|
||||
def _build_payload(self, data: bytes, file_type: int, config: PaddleOCRConfig) -> dict[str, Any]:
|
||||
"""Build payload for API request."""
|
||||
payload: dict[str, Any] = {
|
||||
"file": base64.b64encode(data).decode("ascii"),
|
||||
"fileType": file_type,
|
||||
}
|
||||
|
||||
# Add common parameters
|
||||
for param_key, param_value in [
|
||||
("prettify_markdown", config.prettify_markdown),
|
||||
("show_formula_number", config.show_formula_number),
|
||||
("visualize", config.visualize),
|
||||
]:
|
||||
if param_value is not None:
|
||||
api_param = self._COMMON_FIELD_MAPPING[param_key]
|
||||
payload[api_param] = param_value
|
||||
|
||||
# Add algorithm-specific parameters
|
||||
algorithm_mapping = self._ALGORITHM_FIELD_MAPPINGS.get(config.algorithm, {})
|
||||
for param_key, param_value in config.algorithm_config.items():
|
||||
if param_value is not None and param_key in algorithm_mapping:
|
||||
api_param = algorithm_mapping[param_key]
|
||||
payload[api_param] = param_value
|
||||
|
||||
# Add any additional parameters
|
||||
if config.additional_params:
|
||||
payload.update(config.additional_params)
|
||||
|
||||
return payload
|
||||
|
||||
def _send_request(self, data: bytes, config: PaddleOCRConfig, callback: Optional[Callable[[float, str], None]]) -> dict[str, Any]:
|
||||
"""Send request to PaddleOCR API and parse response."""
|
||||
# Build payload
|
||||
payload = self._build_payload(data, self.file_type, config)
|
||||
|
||||
# Prepare headers
|
||||
headers = {"Content-Type": "application/json", "Client-Platform": "ragflow"}
|
||||
if config.access_token:
|
||||
headers["Authorization"] = f"token {config.access_token}"
|
||||
|
||||
self.logger.info("[PaddleOCR] invoking API")
|
||||
if callback:
|
||||
callback(0.1, "[PaddleOCR] submitting request")
|
||||
|
||||
# Send request
|
||||
try:
|
||||
resp = requests.post(config.api_url, json=payload, headers=headers, timeout=self.request_timeout)
|
||||
resp.raise_for_status()
|
||||
except Exception as exc:
|
||||
if callback:
|
||||
callback(-1, f"[PaddleOCR] request failed: {exc}")
|
||||
raise RuntimeError(f"[PaddleOCR] request failed: {exc}")
|
||||
|
||||
# Parse response
|
||||
try:
|
||||
response_data = resp.json()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"[PaddleOCR] response is not JSON: {exc}") from exc
|
||||
|
||||
if callback:
|
||||
callback(0.8, "[PaddleOCR] response received")
|
||||
|
||||
# Validate response format
|
||||
if response_data.get("errorCode") != 0 or not isinstance(response_data.get("result"), dict):
|
||||
if callback:
|
||||
callback(-1, "[PaddleOCR] invalid response format")
|
||||
raise RuntimeError("[PaddleOCR] invalid response format")
|
||||
|
||||
return response_data["result"]
|
||||
|
||||
def _transfer_to_sections(self, result: dict[str, Any], algorithm: AlgorithmType, parse_method: str) -> list[SectionTuple]:
|
||||
"""Convert API response to section tuples."""
|
||||
sections: list[SectionTuple] = []
|
||||
|
||||
if algorithm == "PaddleOCR-VL":
|
||||
layout_parsing_results = result.get("layoutParsingResults", [])
|
||||
|
||||
for page_idx, layout_result in enumerate(layout_parsing_results):
|
||||
pruned_result = layout_result.get("prunedResult", {})
|
||||
parsing_res_list = pruned_result.get("parsing_res_list", [])
|
||||
|
||||
for block in parsing_res_list:
|
||||
block_content = block.get("block_content", "").strip()
|
||||
if not block_content:
|
||||
continue
|
||||
|
||||
# Remove images
|
||||
block_content = _remove_images_from_markdown(block_content)
|
||||
|
||||
label = block.get("block_label", "")
|
||||
block_bbox = block.get("block_bbox", [0, 0, 0, 0])
|
||||
|
||||
tag = f"@@{page_idx + 1}\t{block_bbox[0] // self._ZOOMIN}\t{block_bbox[2] // self._ZOOMIN}\t{block_bbox[1] // self._ZOOMIN}\t{block_bbox[3] // self._ZOOMIN}##"
|
||||
|
||||
if parse_method == "manual":
|
||||
sections.append((block_content, label, tag))
|
||||
elif parse_method == "paper":
|
||||
sections.append((block_content + tag, label))
|
||||
else:
|
||||
sections.append((block_content, tag))
|
||||
|
||||
return sections
|
||||
|
||||
def _transfer_to_tables(self, result: dict[str, Any]) -> list[TableTuple]:
|
||||
"""Convert API response to table tuples."""
|
||||
return []
|
||||
|
||||
def __images__(self, fnm, page_from=0, page_to=100, callback=None):
|
||||
"""Generate page images from PDF for cropping."""
|
||||
self.page_from = page_from
|
||||
self.page_to = page_to
|
||||
try:
|
||||
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
||||
self.pdf = pdf
|
||||
self.page_images = [p.to_image(resolution=72, antialias=True).original for i, p in enumerate(self.pdf.pages[page_from:page_to])]
|
||||
except Exception as e:
|
||||
self.page_images = None
|
||||
self.logger.exception(e)
|
||||
|
||||
@staticmethod
|
||||
def extract_positions(txt: str):
|
||||
"""Extract position information from text tags."""
|
||||
poss = []
|
||||
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
|
||||
pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
|
||||
left, right, top, bottom = float(left), float(right), float(top), float(bottom)
|
||||
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
|
||||
return poss
|
||||
|
||||
def crop(self, text: str, need_position: bool = False):
|
||||
"""Crop images from PDF based on position tags in text."""
|
||||
imgs = []
|
||||
poss = self.extract_positions(text)
|
||||
|
||||
if not poss:
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
if not getattr(self, "page_images", None):
|
||||
self.logger.warning("[PaddleOCR] crop called without page images; skipping image generation.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
page_count = len(self.page_images)
|
||||
|
||||
filtered_poss = []
|
||||
for pns, left, right, top, bottom in poss:
|
||||
if not pns:
|
||||
self.logger.warning("[PaddleOCR] Empty page index list in crop; skipping this position.")
|
||||
continue
|
||||
valid_pns = [p for p in pns if 0 <= p < page_count]
|
||||
if not valid_pns:
|
||||
self.logger.warning(f"[PaddleOCR] All page indices {pns} out of range for {page_count} pages; skipping.")
|
||||
continue
|
||||
filtered_poss.append((valid_pns, left, right, top, bottom))
|
||||
|
||||
poss = filtered_poss
|
||||
if not poss:
|
||||
self.logger.warning("[PaddleOCR] No valid positions after filtering; skip cropping.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
|
||||
GAP = 6
|
||||
pos = poss[0]
|
||||
first_page_idx = pos[0][0]
|
||||
poss.insert(0, ([first_page_idx], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
|
||||
pos = poss[-1]
|
||||
last_page_idx = pos[0][-1]
|
||||
if not (0 <= last_page_idx < page_count):
|
||||
self.logger.warning(f"[PaddleOCR] Last page index {last_page_idx} out of range for {page_count} pages; skipping crop.")
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
last_page_height = self.page_images[last_page_idx].size[1]
|
||||
poss.append(
|
||||
(
|
||||
[last_page_idx],
|
||||
pos[1],
|
||||
pos[2],
|
||||
min(last_page_height, pos[4] + GAP),
|
||||
min(last_page_height, pos[4] + 120),
|
||||
)
|
||||
)
|
||||
|
||||
positions = []
|
||||
for ii, (pns, left, right, top, bottom) in enumerate(poss):
|
||||
right = left + max_width
|
||||
|
||||
if bottom <= top:
|
||||
bottom = top + 2
|
||||
|
||||
for pn in pns[1:]:
|
||||
if 0 <= pn - 1 < page_count:
|
||||
bottom += self.page_images[pn - 1].size[1]
|
||||
else:
|
||||
self.logger.warning(f"[PaddleOCR] Page index {pn}-1 out of range for {page_count} pages during crop; skipping height accumulation.")
|
||||
|
||||
if not (0 <= pns[0] < page_count):
|
||||
self.logger.warning(f"[PaddleOCR] Base page index {pns[0]} out of range for {page_count} pages during crop; skipping this segment.")
|
||||
continue
|
||||
|
||||
img0 = self.page_images[pns[0]]
|
||||
x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1]))
|
||||
crop0 = img0.crop((x0, y0, x1, y1))
|
||||
imgs.append(crop0)
|
||||
if 0 < ii < len(poss) - 1:
|
||||
positions.append((pns[0] + self.page_from, x0, x1, y0, y1))
|
||||
|
||||
bottom -= img0.size[1]
|
||||
for pn in pns[1:]:
|
||||
if not (0 <= pn < page_count):
|
||||
self.logger.warning(f"[PaddleOCR] Page index {pn} out of range for {page_count} pages during crop; skipping this page.")
|
||||
continue
|
||||
page = self.page_images[pn]
|
||||
x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1]))
|
||||
cimgp = page.crop((x0, y0, x1, y1))
|
||||
imgs.append(cimgp)
|
||||
if 0 < ii < len(poss) - 1:
|
||||
positions.append((pn + self.page_from, x0, x1, y0, y1))
|
||||
bottom -= page.size[1]
|
||||
|
||||
if not imgs:
|
||||
if need_position:
|
||||
return None, None
|
||||
return
|
||||
|
||||
height = 0
|
||||
for img in imgs:
|
||||
height += img.size[1] + GAP
|
||||
height = int(height)
|
||||
width = int(np.max([i.size[0] for i in imgs]))
|
||||
pic = Image.new("RGB", (width, height), (245, 245, 245))
|
||||
height = 0
|
||||
for ii, img in enumerate(imgs):
|
||||
if ii == 0 or ii + 1 == len(imgs):
|
||||
img = img.convert("RGBA")
|
||||
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
||||
overlay.putalpha(128)
|
||||
img = Image.alpha_composite(img, overlay).convert("RGB")
|
||||
pic.paste(img, (0, int(height)))
|
||||
height += img.size[1] + GAP
|
||||
|
||||
if need_position:
|
||||
return pic, positions
|
||||
return pic
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
parser = PaddleOCRParser(api_url=os.getenv("PADDLEOCR_API_URL", ""), algorithm=os.getenv("PADDLEOCR_ALGORITHM", "PaddleOCR-VL"))
|
||||
ok, reason = parser.check_installation()
|
||||
print("PaddleOCR available:", ok, reason)
|
||||
@ -43,6 +43,10 @@ from rag.nlp import rag_tokenizer
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
from common import settings
|
||||
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
@ -88,6 +92,7 @@ class RAGFlowPdfParser:
|
||||
try:
|
||||
pip_install_torch()
|
||||
import torch.cuda
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.updown_cnt_mdl.set_param({"device": "cuda"})
|
||||
except Exception:
|
||||
@ -192,13 +197,112 @@ class RAGFlowPdfParser:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _table_transformer_job(self, ZM):
|
||||
def _evaluate_table_orientation(self, table_img, sample_ratio=0.3):
|
||||
"""
|
||||
Evaluate the best rotation orientation for a table image.
|
||||
|
||||
Tests 4 rotation angles (0°, 90°, 180°, 270°) and uses OCR
|
||||
confidence scores to determine the best orientation.
|
||||
|
||||
Args:
|
||||
table_img: PIL Image object of the table region
|
||||
sample_ratio: Sampling ratio for quick evaluation
|
||||
|
||||
Returns:
|
||||
tuple: (best_angle, best_img, confidence_scores)
|
||||
- best_angle: Best rotation angle (0, 90, 180, 270)
|
||||
- best_img: Image rotated to best orientation
|
||||
- confidence_scores: Dict of scores for each angle
|
||||
"""
|
||||
|
||||
rotations = [
|
||||
(0, "original"),
|
||||
(90, "rotate_90"), # clockwise 90°
|
||||
(180, "rotate_180"), # 180°
|
||||
(270, "rotate_270"), # clockwise 270° (counter-clockwise 90°)
|
||||
]
|
||||
|
||||
results = {}
|
||||
best_score = -1
|
||||
best_angle = 0
|
||||
best_img = table_img
|
||||
|
||||
for angle, name in rotations:
|
||||
# Rotate image
|
||||
if angle == 0:
|
||||
rotated_img = table_img
|
||||
else:
|
||||
# PIL's rotate is counter-clockwise, use negative angle for clockwise
|
||||
rotated_img = table_img.rotate(-angle, expand=True)
|
||||
|
||||
# Convert to numpy array for OCR
|
||||
img_array = np.array(rotated_img)
|
||||
|
||||
# Perform OCR detection and recognition
|
||||
try:
|
||||
ocr_results = self.ocr(img_array)
|
||||
|
||||
if ocr_results:
|
||||
# Calculate average confidence
|
||||
scores = [conf for _, (_, conf) in ocr_results]
|
||||
avg_score = sum(scores) / len(scores) if scores else 0
|
||||
total_regions = len(scores)
|
||||
|
||||
# Combined score: considers both average confidence and number of regions
|
||||
# More regions + higher confidence = better orientation
|
||||
combined_score = avg_score * (1 + 0.1 * min(total_regions, 50) / 50)
|
||||
else:
|
||||
avg_score = 0
|
||||
total_regions = 0
|
||||
combined_score = 0
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"OCR failed for angle {angle}: {e}")
|
||||
avg_score = 0
|
||||
total_regions = 0
|
||||
combined_score = 0
|
||||
|
||||
results[angle] = {"avg_confidence": avg_score, "total_regions": total_regions, "combined_score": combined_score}
|
||||
|
||||
logging.debug(f"Table orientation {angle}°: avg_conf={avg_score:.4f}, regions={total_regions}, combined={combined_score:.4f}")
|
||||
|
||||
if combined_score > best_score:
|
||||
best_score = combined_score
|
||||
best_angle = angle
|
||||
best_img = rotated_img
|
||||
|
||||
logging.info(f"Best table orientation: {best_angle}° (score={best_score:.4f})")
|
||||
|
||||
return best_angle, best_img, results
|
||||
|
||||
def _table_transformer_job(self, ZM, auto_rotate=True):
|
||||
"""
|
||||
Process table structure recognition.
|
||||
|
||||
When auto_rotate=True, the complete workflow:
|
||||
1. Evaluate table orientation and select the best rotation angle
|
||||
2. Use rotated image for table structure recognition (TSR)
|
||||
3. Re-OCR the rotated image
|
||||
4. Match new OCR results with TSR cell coordinates
|
||||
|
||||
Args:
|
||||
ZM: Zoom factor
|
||||
auto_rotate: Whether to enable auto orientation correction
|
||||
"""
|
||||
logging.debug("Table processing...")
|
||||
imgs, pos = [], []
|
||||
tbcnt = [0]
|
||||
MARGIN = 10
|
||||
self.tb_cpns = []
|
||||
self.table_rotations = {} # Store rotation info for each table
|
||||
self.rotated_table_imgs = {} # Store rotated table images
|
||||
|
||||
assert len(self.page_layout) == len(self.page_images)
|
||||
|
||||
# Collect layout info for all tables
|
||||
table_layouts = [] # [(page, table_layout, left, top, right, bott), ...]
|
||||
|
||||
table_index = 0
|
||||
for p, tbls in enumerate(self.page_layout): # for page
|
||||
tbls = [f for f in tbls if f["type"] == "table"]
|
||||
tbcnt.append(len(tbls))
|
||||
@ -210,29 +314,70 @@ class RAGFlowPdfParser:
|
||||
top *= ZM
|
||||
right *= ZM
|
||||
bott *= ZM
|
||||
pos.append((left, top))
|
||||
imgs.append(self.page_images[p].crop((left, top, right, bott)))
|
||||
pos.append((left, top, p, table_index)) # Add page and table_index
|
||||
|
||||
# Record table layout info
|
||||
table_layouts.append({"page": p, "table_index": table_index, "layout": tb, "coords": (left, top, right, bott)})
|
||||
|
||||
# Crop table image
|
||||
table_img = self.page_images[p].crop((left, top, right, bott))
|
||||
|
||||
if auto_rotate:
|
||||
# Evaluate table orientation
|
||||
logging.debug(f"Evaluating orientation for table {table_index} on page {p}")
|
||||
best_angle, rotated_img, rotation_scores = self._evaluate_table_orientation(table_img)
|
||||
|
||||
# Store rotation info
|
||||
self.table_rotations[table_index] = {
|
||||
"page": p,
|
||||
"original_pos": (left, top, right, bott),
|
||||
"best_angle": best_angle,
|
||||
"scores": rotation_scores,
|
||||
"rotated_size": rotated_img.size, # (width, height)
|
||||
}
|
||||
|
||||
# Store the rotated image
|
||||
self.rotated_table_imgs[table_index] = rotated_img
|
||||
imgs.append(rotated_img)
|
||||
|
||||
if best_angle != 0:
|
||||
logging.info(f"Table {table_index} on page {p}: rotated {best_angle}° for better recognition")
|
||||
else:
|
||||
imgs.append(table_img)
|
||||
self.table_rotations[table_index] = {"page": p, "original_pos": (left, top, right, bott), "best_angle": 0, "scores": {}, "rotated_size": table_img.size}
|
||||
self.rotated_table_imgs[table_index] = table_img
|
||||
|
||||
table_index += 1
|
||||
|
||||
assert len(self.page_images) == len(tbcnt) - 1
|
||||
if not imgs:
|
||||
return
|
||||
|
||||
# Perform table structure recognition (TSR)
|
||||
recos = self.tbl_det(imgs)
|
||||
|
||||
# If tables were rotated, re-OCR the rotated images and replace table boxes
|
||||
if auto_rotate:
|
||||
self._ocr_rotated_tables(ZM, table_layouts, recos, tbcnt)
|
||||
|
||||
# Process TSR results (keep original logic but handle rotated coordinates)
|
||||
tbcnt = np.cumsum(tbcnt)
|
||||
for i in range(len(tbcnt) - 1): # for page
|
||||
pg = []
|
||||
for j, tb_items in enumerate(recos[tbcnt[i] : tbcnt[i + 1]]): # for table
|
||||
poss = pos[tbcnt[i] : tbcnt[i + 1]]
|
||||
for it in tb_items: # for table components
|
||||
it["x0"] = it["x0"] + poss[j][0]
|
||||
it["x1"] = it["x1"] + poss[j][0]
|
||||
it["top"] = it["top"] + poss[j][1]
|
||||
it["bottom"] = it["bottom"] + poss[j][1]
|
||||
for n in ["x0", "x1", "top", "bottom"]:
|
||||
it[n] /= ZM
|
||||
it["top"] += self.page_cum_height[i]
|
||||
it["bottom"] += self.page_cum_height[i]
|
||||
it["pn"] = i
|
||||
# TSR coordinates are relative to rotated image, need to record
|
||||
it["x0_rotated"] = it["x0"]
|
||||
it["x1_rotated"] = it["x1"]
|
||||
it["top_rotated"] = it["top"]
|
||||
it["bottom_rotated"] = it["bottom"]
|
||||
|
||||
# For rotated tables, coordinate transformation to page space requires rotation
|
||||
# Since we already re-OCR'd on rotated image, keep simple processing here
|
||||
it["pn"] = poss[j][2] # page number
|
||||
it["layoutno"] = j
|
||||
it["table_index"] = poss[j][3] # table index
|
||||
pg.append(it)
|
||||
self.tb_cpns.extend(pg)
|
||||
|
||||
@ -245,8 +390,9 @@ class RAGFlowPdfParser:
|
||||
headers = gather(r".*header$")
|
||||
rows = gather(r".* (row|header)")
|
||||
spans = gather(r".*spanning")
|
||||
clmns = sorted([r for r in self.tb_cpns if re.match(r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"]))
|
||||
clmns = sorted([r for r in self.tb_cpns if re.match(r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0_rotated"] if "x0_rotated" in x else x["x0"]))
|
||||
clmns = Recognizer.layouts_cleanup(self.boxes, clmns, 5, 0.5)
|
||||
|
||||
for b in self.boxes:
|
||||
if b.get("layout_type", "") != "table":
|
||||
continue
|
||||
@ -278,6 +424,109 @@ class RAGFlowPdfParser:
|
||||
b["H_right"] = spans[ii]["x1"]
|
||||
b["SP"] = ii
|
||||
|
||||
def _ocr_rotated_tables(self, ZM, table_layouts, tsr_results, tbcnt):
|
||||
"""
|
||||
Re-OCR rotated table images and update self.boxes.
|
||||
|
||||
Args:
|
||||
ZM: Zoom factor
|
||||
table_layouts: List of table layout info
|
||||
tsr_results: TSR recognition results
|
||||
tbcnt: Cumulative table count per page
|
||||
"""
|
||||
tbcnt = np.cumsum(tbcnt)
|
||||
|
||||
for tbl_info in table_layouts:
|
||||
table_index = tbl_info["table_index"]
|
||||
page = tbl_info["page"]
|
||||
layout = tbl_info["layout"]
|
||||
left, top, right, bott = tbl_info["coords"]
|
||||
|
||||
rotation_info = self.table_rotations.get(table_index, {})
|
||||
best_angle = rotation_info.get("best_angle", 0)
|
||||
|
||||
# Get the rotated table image
|
||||
rotated_img = self.rotated_table_imgs.get(table_index)
|
||||
if rotated_img is None:
|
||||
continue
|
||||
|
||||
# If table was rotated, re-OCR the rotated image
|
||||
if best_angle != 0:
|
||||
logging.info(f"Re-OCR table {table_index} on page {page} with rotation {best_angle}°")
|
||||
|
||||
# Perform OCR on rotated image
|
||||
img_array = np.array(rotated_img)
|
||||
ocr_results = self.ocr(img_array)
|
||||
|
||||
if not ocr_results:
|
||||
logging.warning(f"No OCR results for rotated table {table_index}")
|
||||
continue
|
||||
|
||||
# Remove original text boxes from this table region in self.boxes
|
||||
# Table region is defined by layout's x0, top, x1, bottom
|
||||
table_x0 = layout["x0"]
|
||||
table_top = layout["top"]
|
||||
table_x1 = layout["x1"]
|
||||
table_bottom = layout["bottom"]
|
||||
|
||||
# Filter out original boxes within the table region
|
||||
original_box_count = len(self.boxes)
|
||||
self.boxes = [
|
||||
b
|
||||
for b in self.boxes
|
||||
if not (
|
||||
b.get("page_number") == page + self.page_from
|
||||
and b.get("layout_type") == "table"
|
||||
and b["x0"] >= table_x0 - 5
|
||||
and b["x1"] <= table_x1 + 5
|
||||
and b["top"] >= table_top - 5
|
||||
and b["bottom"] <= table_bottom + 5
|
||||
)
|
||||
]
|
||||
removed_count = original_box_count - len(self.boxes)
|
||||
logging.debug(f"Removed {removed_count} original boxes from table {table_index}")
|
||||
|
||||
# Add new OCR results to self.boxes
|
||||
# OCR coordinates are relative to rotated image, need to preserve
|
||||
rotated_width, rotated_height = rotated_img.size
|
||||
|
||||
for bbox, (text, conf) in ocr_results:
|
||||
if conf < 0.5: # Filter low confidence results
|
||||
continue
|
||||
|
||||
# bbox format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
x_coords = [p[0] for p in bbox]
|
||||
y_coords = [p[1] for p in bbox]
|
||||
|
||||
# Coordinates in rotated image
|
||||
box_x0 = min(x_coords) / ZM
|
||||
box_x1 = max(x_coords) / ZM
|
||||
box_top = min(y_coords) / ZM
|
||||
box_bottom = max(y_coords) / ZM
|
||||
|
||||
# Create new box, mark as from rotated table
|
||||
new_box = {
|
||||
"text": text,
|
||||
"x0": box_x0 + table_x0, # Coordinates relative to page
|
||||
"x1": box_x1 + table_x0,
|
||||
"top": box_top + table_top + self.page_cum_height[page],
|
||||
"bottom": box_bottom + table_top + self.page_cum_height[page],
|
||||
"page_number": page + self.page_from,
|
||||
"layout_type": "table",
|
||||
"layoutno": f"table-{table_index}",
|
||||
"_rotated": True,
|
||||
"_rotation_angle": best_angle,
|
||||
"_table_index": table_index,
|
||||
# Save original coordinates in rotated image for table reconstruction
|
||||
"_rotated_x0": box_x0,
|
||||
"_rotated_x1": box_x1,
|
||||
"_rotated_top": box_top,
|
||||
"_rotated_bottom": box_bottom,
|
||||
}
|
||||
self.boxes.append(new_box)
|
||||
|
||||
logging.info(f"Added {len(ocr_results)} OCR results from rotated table {table_index}")
|
||||
|
||||
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
|
||||
start = timer()
|
||||
bxs = self.ocr.detect(np.array(img), device_id)
|
||||
@ -408,11 +657,9 @@ class RAGFlowPdfParser:
|
||||
page_cols[pg] = best_k
|
||||
logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}")
|
||||
|
||||
|
||||
global_cols = Counter(page_cols.values()).most_common(1)[0][0]
|
||||
logging.info(f"Global column_num decided by majority: {global_cols}")
|
||||
|
||||
|
||||
for pg, bxs in by_page.items():
|
||||
if not bxs:
|
||||
continue
|
||||
@ -476,7 +723,7 @@ class RAGFlowPdfParser:
|
||||
self.boxes = bxs
|
||||
|
||||
def _naive_vertical_merge(self, zoomin=3):
|
||||
#bxs = self._assign_column(self.boxes, zoomin)
|
||||
# bxs = self._assign_column(self.boxes, zoomin)
|
||||
bxs = self.boxes
|
||||
|
||||
grouped = defaultdict(list)
|
||||
@ -553,7 +800,8 @@ class RAGFlowPdfParser:
|
||||
|
||||
merged_boxes.extend(bxs)
|
||||
|
||||
#self.boxes = sorted(merged_boxes, key=lambda x: (x["page_number"], x.get("col_id", 0), x["top"]))
|
||||
# self.boxes = sorted(merged_boxes, key=lambda x: (x["page_number"], x.get("col_id", 0), x["top"]))
|
||||
self.boxes = merged_boxes
|
||||
|
||||
def _final_reading_order_merge(self, zoomin=3):
|
||||
if not self.boxes:
|
||||
@ -1113,7 +1361,7 @@ class RAGFlowPdfParser:
|
||||
|
||||
if limiter:
|
||||
async with limiter:
|
||||
await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
|
||||
await thread_pool_exec(self.__ocr, i + 1, img, chars, zoomin, id)
|
||||
else:
|
||||
self.__ocr(i + 1, img, chars, zoomin, id)
|
||||
|
||||
@ -1179,10 +1427,26 @@ class RAGFlowPdfParser:
|
||||
if len(self.boxes) == 0 and zoomin < 9:
|
||||
self.__images__(fnm, zoomin * 3, page_from, page_to, callback)
|
||||
|
||||
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
||||
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False, auto_rotate_tables=None):
|
||||
"""
|
||||
Parse a PDF file.
|
||||
|
||||
Args:
|
||||
fnm: PDF file path or binary content
|
||||
need_image: Whether to extract images
|
||||
zoomin: Zoom factor
|
||||
return_html: Whether to return tables in HTML format
|
||||
auto_rotate_tables: Whether to enable auto orientation correction for tables.
|
||||
None: Use TABLE_AUTO_ROTATE env var setting (default: True)
|
||||
True: Enable auto orientation correction
|
||||
False: Disable auto orientation correction
|
||||
"""
|
||||
if auto_rotate_tables is None:
|
||||
auto_rotate_tables = os.getenv("TABLE_AUTO_ROTATE", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
self.__images__(fnm, zoomin)
|
||||
self._layouts_rec(zoomin)
|
||||
self._table_transformer_job(zoomin)
|
||||
self._table_transformer_job(zoomin, auto_rotate=auto_rotate_tables)
|
||||
self._text_merge()
|
||||
self._concat_downward()
|
||||
self._filter_forpages()
|
||||
@ -1200,8 +1464,11 @@ class RAGFlowPdfParser:
|
||||
if callback:
|
||||
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
# Read table auto-rotation setting from environment variable
|
||||
auto_rotate_tables = os.getenv("TABLE_AUTO_ROTATE", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
self._table_transformer_job(zoomin, auto_rotate=auto_rotate_tables)
|
||||
if callback:
|
||||
callback(0.83, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
@ -1493,10 +1760,7 @@ class VisionParser(RAGFlowPdfParser):
|
||||
|
||||
if text:
|
||||
width, height = self.page_images[idx].size
|
||||
all_docs.append((
|
||||
text,
|
||||
f"@@{pdf_page_num + 1}\t{0.0:.1f}\t{width / zoomin:.1f}\t{0.0:.1f}\t{height / zoomin:.1f}##"
|
||||
))
|
||||
all_docs.append((text, f"@@{pdf_page_num + 1}\t{0.0:.1f}\t{width / zoomin:.1f}\t{0.0:.1f}\t{height / zoomin:.1f}##"))
|
||||
return all_docs, []
|
||||
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
@ -48,10 +49,10 @@ class TencentCloudAPIClient:
|
||||
self.secret_key = secret_key
|
||||
self.region = region
|
||||
self.outlines = []
|
||||
|
||||
|
||||
# Create credentials
|
||||
self.cred = credential.Credential(secret_id, secret_key)
|
||||
|
||||
|
||||
# Instantiate an http option, optional, can be skipped if no special requirements
|
||||
self.httpProfile = HttpProfile()
|
||||
self.httpProfile.endpoint = "lkeap.tencentcloudapi.com"
|
||||
@ -59,7 +60,7 @@ class TencentCloudAPIClient:
|
||||
# Instantiate a client option, optional, can be skipped if no special requirements
|
||||
self.clientProfile = ClientProfile()
|
||||
self.clientProfile.httpProfile = self.httpProfile
|
||||
|
||||
|
||||
# Instantiate the client object for the product to be requested, clientProfile is optional
|
||||
self.client = lkeap_client.LkeapClient(self.cred, region, self.clientProfile)
|
||||
|
||||
@ -68,14 +69,14 @@ class TencentCloudAPIClient:
|
||||
try:
|
||||
# Instantiate a request object, each interface corresponds to a request object
|
||||
req = models.ReconstructDocumentSSERequest()
|
||||
|
||||
|
||||
# Build request parameters
|
||||
params = {
|
||||
"FileType": file_type,
|
||||
"FileStartPageNumber": file_start_page,
|
||||
"FileEndPageNumber": file_end_page,
|
||||
}
|
||||
|
||||
|
||||
# According to Tencent Cloud API documentation, either FileUrl or FileBase64 parameter must be provided, if both are provided only FileUrl will be used
|
||||
if file_url:
|
||||
params["FileUrl"] = file_url
|
||||
@ -94,7 +95,7 @@ class TencentCloudAPIClient:
|
||||
# The returned resp is an instance of ReconstructDocumentSSEResponse, corresponding to the request object
|
||||
resp = self.client.ReconstructDocumentSSE(req)
|
||||
parser_result = {}
|
||||
|
||||
|
||||
# Output json format string response
|
||||
if isinstance(resp, types.GeneratorType): # Streaming response
|
||||
logging.info("[TCADP] Detected streaming response")
|
||||
@ -104,7 +105,7 @@ class TencentCloudAPIClient:
|
||||
try:
|
||||
data_dict = json.loads(event['data'])
|
||||
logging.info(f"[TCADP] Parsed data: {data_dict}")
|
||||
|
||||
|
||||
if data_dict.get('Progress') == "100":
|
||||
parser_result = data_dict
|
||||
logging.info("[TCADP] Document parsing completed!")
|
||||
@ -118,14 +119,14 @@ class TencentCloudAPIClient:
|
||||
logging.warning("[TCADP] Failed parsing pages:")
|
||||
for page in failed_pages:
|
||||
logging.warning(f"[TCADP] Page number: {page.get('PageNumber')}, Error: {page.get('ErrorMsg')}")
|
||||
|
||||
|
||||
# Check if there is a download link
|
||||
download_url = data_dict.get("DocumentRecognizeResultUrl")
|
||||
if download_url:
|
||||
logging.info(f"[TCADP] Got download link: {download_url}")
|
||||
else:
|
||||
logging.warning("[TCADP] No download link obtained")
|
||||
|
||||
|
||||
break # Found final result, exit loop
|
||||
else:
|
||||
# Print progress information
|
||||
@ -168,9 +169,6 @@ class TencentCloudAPIClient:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = requests.get(download_url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Ensure output directory exists
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
@ -179,29 +177,36 @@ class TencentCloudAPIClient:
|
||||
filename = f"tcadp_result_{timestamp}.zip"
|
||||
file_path = os.path.join(output_dir, filename)
|
||||
|
||||
# Save file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
with requests.get(download_url, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
response.raw.decode_content = True
|
||||
shutil.copyfileobj(response.raw, f)
|
||||
|
||||
logging.info(f"[TCADP] Document parsing result downloaded to: {os.path.basename(file_path)}")
|
||||
return file_path
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
logging.error(f"[TCADP] Failed to download file: {e}")
|
||||
try:
|
||||
if "file_path" in locals() and os.path.exists(file_path):
|
||||
os.unlink(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class TCADPParser(RAGFlowPdfParser):
|
||||
def __init__(self, secret_id: str = None, secret_key: str = None, region: str = "ap-guangzhou",
|
||||
def __init__(self, secret_id: str = None, secret_key: str = None, region: str = "ap-guangzhou",
|
||||
table_result_type: str = None, markdown_image_response_type: str = None):
|
||||
super().__init__()
|
||||
|
||||
|
||||
# First initialize logger
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
|
||||
# Log received parameters
|
||||
self.logger.info(f"[TCADP] Initializing with parameters - table_result_type: {table_result_type}, markdown_image_response_type: {markdown_image_response_type}")
|
||||
|
||||
|
||||
# Priority: read configuration from RAGFlow configuration system (service_conf.yaml)
|
||||
try:
|
||||
tcadp_parser = get_base_config("tcadp_config", {})
|
||||
@ -212,7 +217,7 @@ class TCADPParser(RAGFlowPdfParser):
|
||||
# Set table_result_type and markdown_image_response_type from config or parameters
|
||||
self.table_result_type = table_result_type if table_result_type is not None else tcadp_parser.get("table_result_type", "1")
|
||||
self.markdown_image_response_type = markdown_image_response_type if markdown_image_response_type is not None else tcadp_parser.get("markdown_image_response_type", "1")
|
||||
|
||||
|
||||
else:
|
||||
self.logger.error("[TCADP] Please configure tcadp_config in service_conf.yaml first")
|
||||
# If config file is empty, use provided parameters or defaults
|
||||
@ -237,6 +242,10 @@ class TCADPParser(RAGFlowPdfParser):
|
||||
if not self.secret_id or not self.secret_key:
|
||||
raise ValueError("[TCADP] Please set Tencent Cloud API keys, configure tcadp_config in service_conf.yaml")
|
||||
|
||||
@staticmethod
|
||||
def _is_zipinfo_symlink(member: zipfile.ZipInfo) -> bool:
|
||||
return (member.external_attr >> 16) & 0o170000 == 0o120000
|
||||
|
||||
def check_installation(self) -> bool:
|
||||
"""Check if Tencent Cloud API configuration is correct"""
|
||||
try:
|
||||
@ -255,7 +264,7 @@ class TCADPParser(RAGFlowPdfParser):
|
||||
|
||||
def _file_to_base64(self, file_path: str, binary: bytes = None) -> str:
|
||||
"""Convert file to Base64 format"""
|
||||
|
||||
|
||||
if binary:
|
||||
# If binary data is directly available, convert directly
|
||||
return base64.b64encode(binary).decode('utf-8')
|
||||
@ -271,23 +280,34 @@ class TCADPParser(RAGFlowPdfParser):
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_file:
|
||||
# Find JSON result files
|
||||
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
||||
members = zip_file.infolist()
|
||||
for member in members:
|
||||
name = member.filename.replace("\\", "/")
|
||||
if member.is_dir():
|
||||
continue
|
||||
if member.flag_bits & 0x1:
|
||||
raise RuntimeError(f"[TCADP] Encrypted zip entry not supported: {member.filename}")
|
||||
if self._is_zipinfo_symlink(member):
|
||||
raise RuntimeError(f"[TCADP] Symlink zip entry not supported: {member.filename}")
|
||||
if name.startswith("/") or name.startswith("//") or re.match(r"^[A-Za-z]:", name):
|
||||
raise RuntimeError(f"[TCADP] Unsafe zip path (absolute): {member.filename}")
|
||||
parts = [p for p in name.split("/") if p not in ("", ".")]
|
||||
if any(p == ".." for p in parts):
|
||||
raise RuntimeError(f"[TCADP] Unsafe zip path (traversal): {member.filename}")
|
||||
|
||||
for json_file in json_files:
|
||||
with zip_file.open(json_file) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
results.extend(data)
|
||||
if not (name.endswith(".json") or name.endswith(".md")):
|
||||
continue
|
||||
|
||||
with zip_file.open(member) as f:
|
||||
if name.endswith(".json"):
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
results.extend(data)
|
||||
else:
|
||||
results.append(data)
|
||||
else:
|
||||
results.append(data)
|
||||
|
||||
# Find Markdown files
|
||||
md_files = [f for f in zip_file.namelist() if f.endswith(".md")]
|
||||
for md_file in md_files:
|
||||
with zip_file.open(md_file) as f:
|
||||
content = f.read().decode("utf-8")
|
||||
results.append({"type": "text", "content": content, "file": md_file})
|
||||
content = f.read().decode("utf-8")
|
||||
results.append({"type": "text", "content": content, "file": name})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"[TCADP] Failed to extract ZIP file content: {e}")
|
||||
@ -395,7 +415,7 @@ class TCADPParser(RAGFlowPdfParser):
|
||||
# Convert file to Base64 format
|
||||
if callback:
|
||||
callback(0.2, "[TCADP] Converting file to Base64 format")
|
||||
|
||||
|
||||
file_base64 = self._file_to_base64(file_path, binary)
|
||||
if callback:
|
||||
callback(0.25, f"[TCADP] File converted to Base64, size: {len(file_base64)} characters")
|
||||
@ -420,23 +440,23 @@ class TCADPParser(RAGFlowPdfParser):
|
||||
"TableResultType": self.table_result_type,
|
||||
"MarkdownImageResponseType": self.markdown_image_response_type
|
||||
}
|
||||
|
||||
|
||||
self.logger.info(f"[TCADP] API request config - TableResultType: {self.table_result_type}, MarkdownImageResponseType: {self.markdown_image_response_type}")
|
||||
|
||||
result = client.reconstruct_document_sse(
|
||||
file_type=file_type,
|
||||
file_base64=file_base64,
|
||||
file_start_page=file_start_page,
|
||||
file_end_page=file_end_page,
|
||||
file_type=file_type,
|
||||
file_base64=file_base64,
|
||||
file_start_page=file_start_page,
|
||||
file_end_page=file_end_page,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
if result:
|
||||
self.logger.info(f"[TCADP] Attempt {attempt + 1} successful")
|
||||
break
|
||||
else:
|
||||
self.logger.warning(f"[TCADP] Attempt {attempt + 1} failed, result is None")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"[TCADP] Attempt {attempt + 1} exception: {e}")
|
||||
if attempt == max_retries - 1:
|
||||
|
||||
@ -18,6 +18,10 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
sys.path.insert(
|
||||
0,
|
||||
os.path.abspath(
|
||||
@ -64,9 +68,9 @@ def main(args):
|
||||
if limiter:
|
||||
async with limiter:
|
||||
print(f"Task {i} use device {id}")
|
||||
await asyncio.to_thread(__ocr, i, id, img)
|
||||
await thread_pool_exec(__ocr, i, id, img)
|
||||
else:
|
||||
await asyncio.to_thread(__ocr, i, id, img)
|
||||
await thread_pool_exec(__ocr, i, id, img)
|
||||
|
||||
|
||||
async def __ocr_launcher():
|
||||
|
||||
24
docker/.env
24
docker/.env
@ -16,6 +16,7 @@
|
||||
# - `infinity` (https://github.com/infiniflow/infinity)
|
||||
# - `oceanbase` (https://github.com/oceanbase/oceanbase)
|
||||
# - `opensearch` (https://github.com/opensearch-project/OpenSearch)
|
||||
# - `seekdb` (https://github.com/oceanbase/seekdb)
|
||||
DOC_ENGINE=${DOC_ENGINE:-elasticsearch}
|
||||
|
||||
# Device on which deepdoc inference run.
|
||||
@ -92,6 +93,19 @@ OB_SYSTEM_MEMORY=${OB_SYSTEM_MEMORY:-2G}
|
||||
OB_DATAFILE_SIZE=${OB_DATAFILE_SIZE:-20G}
|
||||
OB_LOG_DISK_SIZE=${OB_LOG_DISK_SIZE:-20G}
|
||||
|
||||
# The hostname where the SeekDB service is exposed
|
||||
SEEKDB_HOST=seekdb
|
||||
# The port used to expose the SeekDB service
|
||||
SEEKDB_PORT=2881
|
||||
# The username for SeekDB
|
||||
SEEKDB_USER=root
|
||||
# The password for SeekDB
|
||||
SEEKDB_PASSWORD=infini_rag_flow
|
||||
# The doc database of the SeekDB service to use
|
||||
SEEKDB_DOC_DBNAME=ragflow_doc
|
||||
# SeekDB memory limit
|
||||
SEEKDB_MEMORY_LIMIT=2G
|
||||
|
||||
# The password for MySQL.
|
||||
# WARNING: Change this for production!
|
||||
MYSQL_PASSWORD=infini_rag_flow
|
||||
@ -99,9 +113,12 @@ MYSQL_PASSWORD=infini_rag_flow
|
||||
MYSQL_HOST=mysql
|
||||
# The database of the MySQL service to use
|
||||
MYSQL_DBNAME=rag_flow
|
||||
# The port used to connect to MySQL from RAGFlow container.
|
||||
# Change this if you use external MySQL.
|
||||
MYSQL_PORT=3306
|
||||
# The port used to expose the MySQL service to the host machine,
|
||||
# allowing EXTERNAL access to the MySQL database running inside the Docker container.
|
||||
MYSQL_PORT=5455
|
||||
EXPOSE_MYSQL_PORT=5455
|
||||
# The maximum size of communication packets sent to the MySQL server
|
||||
MYSQL_MAX_PACKET=1073741824
|
||||
|
||||
@ -210,6 +227,7 @@ EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-16}
|
||||
# ENDPOINT=http://oss-cn-hangzhou.aliyuncs.com
|
||||
# REGION=cn-hangzhou
|
||||
# BUCKET=ragflow65536
|
||||
#
|
||||
|
||||
# A user registration switch:
|
||||
# - Enable registration: 1
|
||||
@ -255,3 +273,7 @@ DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
|
||||
# RAGFLOW_CRYPTO_ENABLED=true
|
||||
# RAGFLOW_CRYPTO_ALGORITHM=aes-256-cbc # one of aes-256-cbc, aes-128-cbc, sm4-cbc
|
||||
# RAGFLOW_CRYPTO_KEY=ragflow-crypto-key
|
||||
|
||||
|
||||
# Used for ThreadPoolExecutor
|
||||
THREAD_POOL_MAX_WORKERS=128
|
||||
@ -52,6 +52,8 @@ The [.env](./.env) file contains important environment variables for Docker.
|
||||
- `MYSQL_PASSWORD`
|
||||
The password for MySQL.
|
||||
- `MYSQL_PORT`
|
||||
The port to connect to MySQL from RAGFlow container. Defaults to `3306`. Change this if you use an external MySQL.
|
||||
- `EXPOSE_MYSQL_PORT`
|
||||
The port used to expose the MySQL service to the host machine, allowing **external** access to the MySQL database running inside the Docker container. Defaults to `5455`.
|
||||
|
||||
### MinIO
|
||||
|
||||
@ -72,7 +72,7 @@ services:
|
||||
infinity:
|
||||
profiles:
|
||||
- infinity
|
||||
image: infiniflow/infinity:v0.6.15
|
||||
image: infiniflow/infinity:v0.7.0-dev1
|
||||
volumes:
|
||||
- infinity_data:/var/infinity
|
||||
- ./infinity_conf.toml:/infinity_conf.toml
|
||||
@ -121,6 +121,30 @@ services:
|
||||
- ragflow
|
||||
restart: unless-stopped
|
||||
|
||||
seekdb:
|
||||
profiles:
|
||||
- seekdb
|
||||
image: oceanbase/seekdb:latest
|
||||
container_name: seekdb
|
||||
volumes:
|
||||
- ./seekdb:/var/lib/oceanbase
|
||||
ports:
|
||||
- ${SEEKDB_PORT:-2881}:2881
|
||||
env_file: .env
|
||||
environment:
|
||||
- ROOT_PASSWORD=${SEEKDB_PASSWORD:-infini_rag_flow}
|
||||
- MEMORY_LIMIT=${SEEKDB_MEMORY_LIMIT:-2G}
|
||||
- REPORTER=ragflow-seekdb
|
||||
mem_limit: ${MEM_LIMIT}
|
||||
healthcheck:
|
||||
test: ['CMD-SHELL', 'mysql -h127.0.0.1 -P2881 -uroot -p${SEEKDB_PASSWORD:-infini_rag_flow} -e "CREATE DATABASE IF NOT EXISTS ${SEEKDB_DOC_DBNAME:-ragflow_doc};"']
|
||||
interval: 5s
|
||||
retries: 60
|
||||
timeout: 5s
|
||||
networks:
|
||||
- ragflow
|
||||
restart: unless-stopped
|
||||
|
||||
sandbox-executor-manager:
|
||||
profiles:
|
||||
- sandbox
|
||||
@ -164,7 +188,7 @@ services:
|
||||
--init-file /data/application/init.sql
|
||||
--binlog_expire_logs_seconds=604800
|
||||
ports:
|
||||
- ${MYSQL_PORT}:3306
|
||||
- ${EXPOSE_MYSQL_PORT}:3306
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
- ./init.sql:/data/application/init.sql
|
||||
@ -283,6 +307,8 @@ volumes:
|
||||
driver: local
|
||||
ob_data:
|
||||
driver: local
|
||||
seekdb_data:
|
||||
driver: local
|
||||
mysql_data:
|
||||
driver: local
|
||||
minio_data:
|
||||
|
||||
@ -39,7 +39,6 @@ services:
|
||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
- ../history_data_agent:/ragflow/history_data_agent
|
||||
- ./service_conf.yaml.template:/ragflow/conf/service_conf.yaml.template
|
||||
- ./entrypoint.sh:/ragflow/entrypoint.sh
|
||||
env_file: .env
|
||||
@ -88,7 +87,6 @@ services:
|
||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
- ../history_data_agent:/ragflow/history_data_agent
|
||||
- ./service_conf.yaml.template:/ragflow/conf/service_conf.yaml.template
|
||||
- ./entrypoint.sh:/ragflow/entrypoint.sh
|
||||
env_file: .env
|
||||
|
||||
@ -156,8 +156,20 @@ TEMPLATE_FILE="${CONF_DIR}/service_conf.yaml.template"
|
||||
CONF_FILE="${CONF_DIR}/service_conf.yaml"
|
||||
|
||||
rm -f "${CONF_FILE}"
|
||||
DEF_ENV_VALUE_PATTERN="\$\{([^:]+):-([^}]+)\}"
|
||||
while IFS= read -r line || [[ -n "$line" ]]; do
|
||||
eval "echo \"$line\"" >> "${CONF_FILE}"
|
||||
if [[ "$line" =~ DEF_ENV_VALUE_PATTERN ]]; then
|
||||
varname="${BASH_REMATCH[1]}"
|
||||
default="${BASH_REMATCH[2]}"
|
||||
|
||||
if [ -n "${!varname}" ]; then
|
||||
eval "echo \"$line"\" >> "${CONF_FILE}"
|
||||
else
|
||||
echo "$line" | sed -E "s/\\\$\{[^:]+:-([^}]+)\}/\1/g" >> "${CONF_FILE}"
|
||||
fi
|
||||
else
|
||||
eval "echo \"$line\"" >> "${CONF_FILE}"
|
||||
fi
|
||||
done < "${TEMPLATE_FILE}"
|
||||
|
||||
export LD_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu/"
|
||||
@ -195,10 +207,9 @@ function start_mcp_server() {
|
||||
|
||||
function ensure_docling() {
|
||||
[[ "${USE_DOCLING}" == "true" ]] || { echo "[docling] disabled by USE_DOCLING"; return 0; }
|
||||
python3 -c 'import pip' >/dev/null 2>&1 || python3 -m ensurepip --upgrade || true
|
||||
DOCLING_PIN="${DOCLING_VERSION:-==2.58.0}"
|
||||
python3 -c "import importlib.util,sys; sys.exit(0 if importlib.util.find_spec('docling') else 1)" \
|
||||
|| python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --extra-index-url https://pypi.org/simple --no-cache-dir "docling${DOCLING_PIN}"
|
||||
"$PY" -c "import importlib.util,sys; sys.exit(0 if importlib.util.find_spec('docling') else 1)" \
|
||||
|| uv pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --extra-index-url https://pypi.org/simple --no-cache-dir "docling${DOCLING_PIN}"
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
[general]
|
||||
version = "0.6.15"
|
||||
version = "0.7.0"
|
||||
time_zone = "utc-8"
|
||||
|
||||
[network]
|
||||
|
||||
@ -9,7 +9,7 @@ mysql:
|
||||
user: '${MYSQL_USER:-root}'
|
||||
password: '${MYSQL_PASSWORD:-infini_rag_flow}'
|
||||
host: '${MYSQL_HOST:-mysql}'
|
||||
port: 3306
|
||||
port: ${MYSQL_PORT:-3306}
|
||||
max_connections: 900
|
||||
stale_timeout: 300
|
||||
max_allowed_packet: ${MYSQL_MAX_PACKET:-1073741824}
|
||||
@ -29,6 +29,7 @@ os:
|
||||
password: '${OPENSEARCH_PASSWORD:-infini_rag_flow_OS_01}'
|
||||
infinity:
|
||||
uri: '${INFINITY_HOST:-infinity}:23817'
|
||||
postgres_port: 5432
|
||||
db_name: 'default_db'
|
||||
oceanbase:
|
||||
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config
|
||||
@ -38,6 +39,14 @@ oceanbase:
|
||||
password: '${OCEANBASE_PASSWORD:-infini_rag_flow}'
|
||||
host: '${OCEANBASE_HOST:-oceanbase}'
|
||||
port: ${OCEANBASE_PORT:-2881}
|
||||
seekdb:
|
||||
scheme: 'oceanbase' # SeekDB is the lite version of OceanBase
|
||||
config:
|
||||
db_name: '${SEEKDB_DOC_DBNAME:-ragflow_doc}'
|
||||
user: '${SEEKDB_USER:-root}'
|
||||
password: '${SEEKDB_PASSWORD:-infini_rag_flow}'
|
||||
host: '${SEEKDB_HOST:-seekdb}'
|
||||
port: ${SEEKDB_PORT:-2881}
|
||||
redis:
|
||||
db: 1
|
||||
username: '${REDIS_USERNAME:-}'
|
||||
@ -72,6 +81,8 @@ user_default_llm:
|
||||
# region: '${REGION}'
|
||||
# bucket: '${BUCKET}'
|
||||
# prefix_path: '${OSS_PREFIX_PATH}'
|
||||
# signature_version: 's3'
|
||||
# addressing_style: 'virtual'
|
||||
# azure:
|
||||
# auth_type: 'sas'
|
||||
# container_url: 'container_url'
|
||||
|
||||
@ -3,7 +3,7 @@ sidebar_position: 1
|
||||
slug: /what-is-rag
|
||||
---
|
||||
|
||||
# What is Retreival-Augmented-Generation (RAG)?
|
||||
# What is Retreival-Augmented-Generation (RAG)?
|
||||
|
||||
Since large language models (LLMs) became the focus of technology, their ability to handle general knowledge has been astonishing. However, when questions shift to internal corporate documents, proprietary knowledge bases, or real-time data, the limitations of LLMs become glaringly apparent: they cannot access private information outside their training data. Retrieval-Augmented Generation (RAG) was born precisely to address this core need. Before an LLM generates an answer, it first retrieves the most relevant context from an external knowledge base and inputs it as "reference material" to the LLM, thereby guiding it to produce accurate answers. In short, RAG elevates LLMs from "relying on memory" to "having evidence to rely on," significantly improving their accuracy and trustworthiness in specialized fields and real-time information queries.
|
||||
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
---
|
||||
sidebar_position: 1
|
||||
slug: /configurations
|
||||
sidebar_custom_props: {
|
||||
sidebarIcon: LucideCog
|
||||
}
|
||||
---
|
||||
|
||||
# Configuration
|
||||
|
||||
Configurations for deploying RAGFlow via Docker.
|
||||
@ -70,6 +72,8 @@ The [.env](https://github.com/infiniflow/ragflow/blob/main/docker/.env) file con
|
||||
- `MYSQL_PASSWORD`
|
||||
The password for MySQL.
|
||||
- `MYSQL_PORT`
|
||||
The port to connect to MySQL from RAGFlow container. Defaults to `3306`. Change this if you use an external MySQL.
|
||||
- `EXPOSE_MYSQL_PORT`
|
||||
The port used to expose the MySQL service to the host machine, allowing **external** access to the MySQL database running inside the Docker container. Defaults to `5455`.
|
||||
|
||||
### MinIO
|
||||
|
||||
@ -4,5 +4,8 @@
|
||||
"link": {
|
||||
"type": "generated-index",
|
||||
"description": "Miscellaneous contribution guides."
|
||||
},
|
||||
"customProps": {
|
||||
"sidebarIcon": "LucideHandshake"
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
---
|
||||
sidebar_position: 1
|
||||
slug: /contributing
|
||||
sidebar_custom_props: {
|
||||
categoryIcon: LucideBookA
|
||||
}
|
||||
---
|
||||
|
||||
# Contribution guidelines
|
||||
|
||||
General guidelines for RAGFlow's community contributors.
|
||||
|
||||
@ -4,5 +4,8 @@
|
||||
"link": {
|
||||
"type": "generated-index",
|
||||
"description": "Guides for hardcore developers"
|
||||
},
|
||||
"customProps": {
|
||||
"sidebarIcon": "LucideWrench"
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
---
|
||||
sidebar_position: 4
|
||||
slug: /acquire_ragflow_api_key
|
||||
sidebar_custom_props: {
|
||||
categoryIcon: LucideKey
|
||||
}
|
||||
---
|
||||
|
||||
# Acquire RAGFlow API key
|
||||
|
||||
An API key is required for the RAGFlow server to authenticate your HTTP/Python or MCP requests. This documents provides instructions on obtaining a RAGFlow API key.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user