Compare commits

..

24 Commits

Author SHA1 Message Date
427e0540ca Fix: table tag on chunks. (#12126)
- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 11:23:51 +08:00
8f16fac898 Fix: Add a no-data filter condition to MetaData (#12189)
### What problem does this PR solve?

Fix: Add a no-data filter condition to MetaData

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-25 10:42:34 +08:00
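As a rough, hypothetical illustration of a "no-data" filter on metadata (#12189): skip entries whose metadata is missing or empty before aggregating. The `meta_fields` key and the aggregation helper below are assumptions for illustration, not the PR's actual code.

```python
# Hedged sketch: a "no-data" filter applied while aggregating document metadata.
# The doc shape ({"meta_fields": {...}}) is an assumption for illustration only.
from collections import defaultdict


def aggregate_metadata(docs):
    """Collect metadata values per key, skipping documents with no metadata."""
    summary = defaultdict(set)
    for doc in docs:
        meta = doc.get("meta_fields")
        if not meta:  # no-data filter: drop None or empty metadata outright
            continue
        for key, value in meta.items():
            if value in (None, "", [], {}):  # also skip empty values
                continue
            summary[key].add(str(value))
    return {k: sorted(v) for k, v in summary.items()}


if __name__ == "__main__":
    docs = [{"meta_fields": {"author": "a"}}, {"meta_fields": None}, {}]
    print(aggregate_metadata(docs))  # {'author': ['a']}
```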
ea00478a21 Bump infinity to 0.6.13 (#12181)
### What problem does this PR solve?

Bump infinity to 0.6.13

### Type of change

- [x] Refactoring
2025-12-24 22:33:56 +08:00
906f19e863 Dragging down a downstream node of a Switch operator will cause the end_cpn_ids to contain the ID of the placeholder operator. #12177 (#12178)
### What problem does this PR solve?

Dragging down a downstream node of a Switch operator will cause the
end_cpn_ids to contain the ID of the placeholder operator. #12177

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:45:35 +08:00
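A small, hedged sketch of the idea behind #12178: when resolving the downstream targets of a Switch branch, placeholder nodes created by drag operations are skipped so that end_cpn_ids only contains real operators. The node shape and the "placeholder" naming are assumptions, not the actual implementation.

```python
# Illustrative only: filter placeholder nodes out of a Switch branch's targets.
def collect_end_cpn_ids(downstream_nodes):
    """Return IDs of real downstream components, ignoring placeholder nodes."""
    end_cpn_ids = []
    for node in downstream_nodes:
        node_id = str(node.get("id", ""))
        if node.get("is_placeholder") or node_id.startswith("placeholder"):
            continue  # drag-created placeholders must not end up in end_cpn_ids
        end_cpn_ids.append(node_id)
    return end_cpn_ids


print(collect_end_cpn_ids([
    {"id": "switch_1"},
    {"id": "placeholder_0", "is_placeholder": True},
]))  # ['switch_1']
```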
667dc5467e Fix: incorrect agent translation text. #10427 (#12172)
### What problem does this PR solve?

Fix: incorrect agent translation text. #10427

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 19:05:26 +08:00
977962fdfe Fix: loopitem None issue. (#12166)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:22:31 +08:00
1e9374a373 Fix: Metadata saving, copywriting, and other related issues (#12169)
### What problem does this PR solve?

Fixes:
- Text overflow issues that caused rendering problems
- Metadata saving, copywriting, and other related issues

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 17:21:36 +08:00
9b52ba8061 Feat: add image table context to pipeline splitter (#12167)
### What problem does this PR solve?

Add image table context to pipeline splitter.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 16:58:14 +08:00
44671ea413 Fix: type check for chunks (#12164)
### What problem does this PR solve?

Fix: type check for chunks

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 16:36:00 +08:00
c81421d340 Feat: add document metadata setting (#12156)
### What problem does this PR solve?

Add document metadata setting.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:13:50 +08:00
ee93a80e91 Feat: add MiniMax M2.1 (#12148)
### What problem does this PR solve?

Add MiniMax M2.1.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 16:08:09 +08:00
5c981978c1 Run infinity test before ES (#12159)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 15:52:55 +08:00
7fef285af5 Revert "Bump infinity to 0.6.12 (#12140)" (#12161)
This reverts commit 0588fe79b9.
2025-12-24 15:24:12 +08:00
b1efb905e5 Fix: metadata_obj issue. (#12146)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:40:34 +08:00
6400bf87ba Fix: LLM tool does not exist in multiple retrieval case (#12143)
### What problem does this PR solve?

Fix the issue where the LLM tool does not exist in the multiple-retrieval case.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 13:26:48 +08:00
f239bc02d3 Feat: Support Markdown Rendering for tips in user-fill-up Component #11825 (#12147)
### What problem does this PR solve?

Feat: Support Markdown Rendering for tips in user-fill-up Component
#11825

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 13:25:56 +08:00
5776fa73a7 refactor: improve memory service date time consistency (#12144)
### What problem does this PR solve?

Improve memory service date-time consistency.

### Type of change

- [x] Refactoring
2025-12-24 11:00:31 +08:00
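One common way to get the date-time consistency that #12144 aims for is to route every timestamp through a single timezone-aware UTC helper and one serialization format; the sketch below is illustrative, and the helper names are assumptions rather than the memory service's actual API.

```python
# Hedged sketch: one UTC "now" helper plus one serializer keeps timestamps consistent.
from datetime import datetime, timezone


def utc_now() -> datetime:
    """Single source of truth for 'now': always timezone-aware UTC."""
    return datetime.now(timezone.utc)


def to_iso(dt: datetime) -> str:
    """Serialize uniformly; naive datetimes are assumed to be UTC."""
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc).isoformat(timespec="seconds")


print(to_iso(utc_now()))  # e.g. 2025-12-24T03:00:31+00:00
```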
fc6af1998b Doc: Added an HTTP request component reference (#12141)
### Type of change

- [x] Documentation Update
2025-12-24 09:35:32 +08:00
0588fe79b9 Bump infinity to 0.6.12 (#12140)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-12-24 09:34:54 +08:00
f545265f93 Fix: remove duplicate tool_meta (#12139)
### What problem does this PR solve?
Related PR: #12117
Change: remove duplicate tool_meta

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:34:08 +08:00
c987d33649 Feat: deduplicate metadata lists during updates (#12125)
### What problem does this PR solve?

Deduplicate metadata lists during updates.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-24 09:32:55 +08:00
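Deduplicating metadata lists during an update (#12125) typically means merging incoming values with the stored ones while dropping repeats and keeping first-seen order. The sketch below is a minimal illustration under that assumption; the merge signature and field shapes are not the PR's actual code.

```python
# Hedged sketch: order-preserving deduplication when merging metadata updates.
def dedup_list(values):
    """Drop duplicates while preserving first-occurrence order."""
    seen, out = set(), []
    for v in values:
        if v not in seen:
            seen.add(v)
            out.append(v)
    return out


def merge_metadata(existing: dict, update: dict) -> dict:
    """Merge an update into existing metadata, deduplicating list-valued fields."""
    merged = dict(existing)
    for key, value in update.items():
        if isinstance(value, list):
            merged[key] = dedup_list(list(existing.get(key, [])) + value)
        else:
            merged[key] = value
    return merged


print(merge_metadata({"tags": ["a", "b"]}, {"tags": ["b", "c"]}))  # {'tags': ['a', 'b', 'c']}
```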
d72debf0db Fix: Add prompts when merging or deleting metadata. (#12138)
### What problem does this PR solve?

Fix: Add prompts when merging or deleting metadata.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
2025-12-24 09:32:41 +08:00
c33134ea2c Fix: table tag on chunks. (#12126)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-24 09:32:19 +08:00
17b8bb62b6 Feat: message manage (#12083)
### What problem does this PR solve?

Message CRUD.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 21:16:25 +08:00
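As a rough illustration of what message CRUD (#12083) covers, here is a self-contained in-memory sketch; the store, field names, and ID scheme are assumptions, not the feature's actual schema or service layer.

```python
# Hedged sketch: minimal in-memory message CRUD, for illustration only.
import uuid


class MessageStore:
    def __init__(self):
        self._messages = {}

    def create(self, conversation_id: str, role: str, content: str) -> dict:
        msg = {
            "id": uuid.uuid4().hex,
            "conversation_id": conversation_id,
            "role": role,
            "content": content,
        }
        self._messages[msg["id"]] = msg
        return msg

    def get(self, message_id: str):
        return self._messages.get(message_id)

    def update(self, message_id: str, **fields):
        msg = self._messages.get(message_id)
        if msg:
            msg.update(fields)
        return msg

    def delete(self, message_id: str) -> bool:
        return self._messages.pop(message_id, None) is not None


store = MessageStore()
m = store.create("conv_1", "user", "hello")
store.update(m["id"], content="hello!")
print(store.get(m["id"])["content"])  # hello!
store.delete(m["id"])
```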
1036 changed files with 44583 additions and 74868 deletions

View File

@ -1,22 +1 @@
# Project instructions for Copilot
## How to run (minimum)
- Install:
- python -m venv .venv && source .venv/bin/activate
- pip install -r requirements.txt
- Run:
- (fill) e.g. uvicorn app.main:app --reload
- Verify:
- (fill) curl http://127.0.0.1:8000/health
## Project layout (what matters)
- app/: API entrypoints + routers
- services/: business logic
- configs/: config loading (.env)
- docs/: documents
- tests/: pytest
## Conventions
- Prefer small, incremental changes.
- Add logging for new flows.
- Add/adjust tests for behavior changes.
Refer to [AGENTS.MD](../AGENTS.md) for all repo instructions.

View File

@ -10,12 +10,6 @@ on:
tags:
- "v*.*.*" # normal release
permissions:
contents: write
actions: read
checks: read
statuses: read
# https://docs.github.com/en/actions/using-jobs/using-concurrency
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@ -82,14 +76,6 @@ jobs:
# The body field does not support environment variable substitution directly.
body_path: release_body.md
- name: Build and push image
run: |
sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
sudo docker push infiniflow/ragflow:${RELEASE_TAG}
sudo docker push infiniflow/ragflow:latest
- name: Build and push ragflow-sdk
if: startsWith(github.ref, 'refs/tags/v')
run: |
@ -99,3 +85,11 @@ jobs:
if: startsWith(github.ref, 'refs/tags/v')
run: |
cd admin/client && uv build && uv publish --token ${{ secrets.PYPI_API_TOKEN }}
- name: Build and push image
run: |
sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
sudo docker push infiniflow/ragflow:${RELEASE_TAG}
sudo docker push infiniflow/ragflow:latest

View File

@ -86,9 +86,6 @@ jobs:
mkdir -p ${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}
echo "${PR_SHA} ${GITHUB_RUN_ID}" > ${PR_SHA_FP}
fi
ARTIFACTS_DIR=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/${GITHUB_RUN_ID}
echo "ARTIFACTS_DIR=${ARTIFACTS_DIR}" >> ${GITHUB_ENV}
rm -rf ${ARTIFACTS_DIR} && mkdir -p ${ARTIFACTS_DIR}
# https://github.com/astral-sh/ruff-action
- name: Static check with Ruff
@ -164,7 +161,7 @@ jobs:
INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10))
INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10))
INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10))
EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
MINIO_PORT=$((9000 + RUNNER_NUM * 10))
MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10))
REDIS_PORT=$((6379 + RUNNER_NUM * 10))
@ -184,7 +181,7 @@ jobs:
echo -e "INFINITY_THRIFT_PORT=${INFINITY_THRIFT_PORT}" >> docker/.env
echo -e "INFINITY_HTTP_PORT=${INFINITY_HTTP_PORT}" >> docker/.env
echo -e "INFINITY_PSQL_PORT=${INFINITY_PSQL_PORT}" >> docker/.env
echo -e "EXPOSE_MYSQL_PORT=${EXPOSE_MYSQL_PORT}" >> docker/.env
echo -e "MYSQL_PORT=${MYSQL_PORT}" >> docker/.env
echo -e "MINIO_PORT=${MINIO_PORT}" >> docker/.env
echo -e "MINIO_CONSOLE_PORT=${MINIO_CONSOLE_PORT}" >> docker/.env
echo -e "REDIS_PORT=${REDIS_PORT}" >> docker/.env
@ -200,188 +197,38 @@ jobs:
echo -e "COMPOSE_PROFILES=\${COMPOSE_PROFILES},tei-cpu" >> docker/.env
echo -e "TEI_MODEL=BAAI/bge-small-en-v1.5" >> docker/.env
echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
sed -i '1i DOC_ENGINE=infinity' docker/.env
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
# Patch entrypoint.sh for coverage
sed -i '/"\$PY" api\/ragflow_server.py \${INIT_SUPERUSER_ARGS} &/c\ echo "Ensuring coverage is installed..."\n "$PY" -m pip install coverage\n export COVERAGE_FILE=/ragflow/logs/.coverage\n echo "Starting ragflow_server with coverage..."\n "$PY" -m coverage run --source=./api/apps --omit="*/tests/*,*/migrations/*" -a api/ragflow_server.py ${INIT_SUPERUSER_ARGS} &' docker/entrypoint.sh
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
uv sync --python 3.12 --group test --frozen && uv pip install -e sdk/python
uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
- name: Run sdk tests against Elasticsearch
- name: Run sdk tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-es-sdk.xml test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
- name: Run web api tests against Elasticsearch
- name: Run frontend api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api 2>&1 | tee es_web_api_test.log
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
- name: Run http api tests against Elasticsearch
- name: Run http api tests against Infinity
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
- name: RAGFlow CLI retrieval test Elasticsearch
env:
PYTHONPATH: ${{ github.workspace }}
run: |
set -euo pipefail
source .venv/bin/activate
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
EMAIL="ci-${GITHUB_RUN_ID}@example.com"
PASS="ci-pass-${GITHUB_RUN_ID}"
DATASET="ci_dataset_${GITHUB_RUN_ID}"
CLI="python admin/client/ragflow_cli.py"
LOG_FILE="es_cli_test.log"
: > "${LOG_FILE}"
ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]'
run_cli() {
local logfile="$1"
shift
local allow_re=""
if [[ "${1:-}" == "--allow" ]]; then
allow_re="$2"
shift 2
fi
local cmd_display="$*"
echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}"
local tmp_log
tmp_log="$(mktemp)"
set +e
timeout 180s "$@" 2>&1 | tee "${tmp_log}"
local status=${PIPESTATUS[0]}
set -e
cat "${tmp_log}" >> "${logfile}"
if grep -qiE "${ERROR_RE}" "${tmp_log}"; then
if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then
echo "Allowed CLI error markers in ${logfile}"
rm -f "${tmp_log}"
return 0
fi
echo "Detected CLI error markers in ${logfile}"
rm -f "${tmp_log}"
exit 1
fi
rm -f "${tmp_log}"
return ${status}
}
set -a
source docker/.env
set +a
HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}"
USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')"
USER_PORT="${SVR_HTTP_PORT}"
ADMIN_HOST="${USER_HOST}"
ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}"
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
admin_ready=0
for i in $(seq 1 30); do
if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then
admin_ready=1
break
fi
sleep 1
done
if [[ "${admin_ready}" -ne 1 ]]; then
echo "Admin service did not become ready"
exit 1
fi
run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version"
ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?'
run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'"
user_ready=0
for i in $(seq 1 30); do
if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then
user_ready=1
break
fi
sleep 1
done
if [[ "${user_ready}" -ne 1 ]]; then
echo "User service did not become ready"
exit 1
fi
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'"
- name: Stop ragflow to save coverage Elasticsearch
if: ${{ !cancelled() }}
run: |
# Send SIGINT to ragflow_server.py to trigger coverage save
PID=$(sudo docker exec ${RAGFLOW_CONTAINER} ps aux | grep "ragflow_server.py" | grep -v grep | awk '{print $2}' | head -n 1)
if [ -n "$PID" ]; then
echo "Sending SIGINT to ragflow_server.py (PID: $PID)..."
sudo docker exec ${RAGFLOW_CONTAINER} kill -INT $PID
# Wait for process to exit and coverage file to be written
sleep 10
else
echo "ragflow_server.py not found!"
fi
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} stop
- name: Generate server coverage report Elasticsearch
if: ${{ !cancelled() }}
run: |
# .coverage file should be in docker/ragflow-logs/.coverage
if [ -f docker/ragflow-logs/.coverage ]; then
echo "Found .coverage file"
cp docker/ragflow-logs/.coverage .coverage
source .venv/bin/activate
# Create .coveragerc to map container paths to host paths
echo "[paths]" > .coveragerc
echo "source =" >> .coveragerc
echo " ." >> .coveragerc
echo " /ragflow" >> .coveragerc
coverage xml -o coverage-es-server.xml
rm .coveragerc
# Clean up for next run
sudo rm docker/ragflow-logs/.coverage
else
echo ".coverage file not found!"
fi
- name: Collect ragflow log Elasticsearch
if: ${{ !cancelled() }}
run: |
if [ -d docker/ragflow-logs ]; then
cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-es
echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log || true
else
echo "No docker/ragflow-logs directory found; skipping log collection"
fi
sudo rm -rf docker/ragflow-logs || true
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed
@ -391,188 +238,35 @@ jobs:
- name: Start ragflow:nightly
run: |
sed -i '1i DOC_ENGINE=infinity' docker/.env
sed -i '1i DOC_ENGINE=elasticsearch' docker/.env
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
- name: Run sdk tests against Infinity
- name: Run sdk tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-infinity-sdk.xml test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
- name: Run web api tests against Infinity
- name: Run frontend api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/test_api_app 2>&1 | tee infinity_web_api_test.log
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
- name: Run http api tests against Infinity
- name: Run http api tests against Elasticsearch
run: |
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
- name: RAGFlow CLI retrieval test Infinity
env:
PYTHONPATH: ${{ github.workspace }}
run: |
set -euo pipefail
source .venv/bin/activate
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
EMAIL="ci-${GITHUB_RUN_ID}@example.com"
PASS="ci-pass-${GITHUB_RUN_ID}"
DATASET="ci_dataset_${GITHUB_RUN_ID}"
CLI="python admin/client/ragflow_cli.py"
LOG_FILE="infinity_cli_test.log"
: > "${LOG_FILE}"
ERROR_RE='Traceback|ModuleNotFoundError|ImportError|Parse error|Bad response|Fail to|code:\\s*[1-9]'
run_cli() {
local logfile="$1"
shift
local allow_re=""
if [[ "${1:-}" == "--allow" ]]; then
allow_re="$2"
shift 2
fi
local cmd_display="$*"
echo "===== $(date -u +\"%Y-%m-%dT%H:%M:%SZ\") CMD: ${cmd_display} =====" | tee -a "${logfile}"
local tmp_log
tmp_log="$(mktemp)"
set +e
timeout 180s "$@" 2>&1 | tee "${tmp_log}"
local status=${PIPESTATUS[0]}
set -e
cat "${tmp_log}" >> "${logfile}"
if grep -qiE "${ERROR_RE}" "${tmp_log}"; then
if [[ -n "${allow_re}" ]] && grep -qiE "${allow_re}" "${tmp_log}"; then
echo "Allowed CLI error markers in ${logfile}"
rm -f "${tmp_log}"
return 0
fi
echo "Detected CLI error markers in ${logfile}"
rm -f "${tmp_log}"
exit 1
fi
rm -f "${tmp_log}"
return ${status}
}
set -a
source docker/.env
set +a
HOST_ADDRESS="http://host.docker.internal:${SVR_HTTP_PORT}"
USER_HOST="$(echo "${HOST_ADDRESS}" | sed -E 's#^https?://([^:/]+).*#\1#')"
USER_PORT="${SVR_HTTP_PORT}"
ADMIN_HOST="${USER_HOST}"
ADMIN_PORT="${ADMIN_SVR_HTTP_PORT}"
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
echo "Waiting for service to be available..."
sleep 5
done
admin_ready=0
for i in $(seq 1 30); do
if run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "ping"; then
admin_ready=1
break
fi
sleep 1
done
if [[ "${admin_ready}" -ne 1 ]]; then
echo "Admin service did not become ready"
exit 1
fi
run_cli "${LOG_FILE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "show version"
ALLOW_USER_EXISTS_RE='already exists|already exist|duplicate|already.*registered|exist(s)?'
run_cli "${LOG_FILE}" --allow "${ALLOW_USER_EXISTS_RE}" $CLI --type admin --host "$ADMIN_HOST" --port "$ADMIN_PORT" --username "admin@ragflow.io" --password "admin" command "create user '$EMAIL' '$PASS'"
user_ready=0
for i in $(seq 1 30); do
if run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "ping"; then
user_ready=1
break
fi
sleep 1
done
if [[ "${user_ready}" -ne 1 ]]; then
echo "User service did not become ready"
exit 1
fi
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "show version"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "create dataset '$DATASET' with embedding 'BAAI/bge-small-en-v1.5@Builtin' parser 'auto'"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "import 'test/benchmark/test_docs/Doc1.pdf,test/benchmark/test_docs/Doc2.pdf' into dataset '$DATASET'"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "parse dataset '$DATASET' sync"
run_cli "${LOG_FILE}" $CLI --type user --host "$USER_HOST" --port "$USER_PORT" --username "$EMAIL" --password "$PASS" command "Benchmark 16 100 search 'what are these documents about' on datasets '$DATASET'"
- name: Stop ragflow to save coverage Infinity
if: ${{ !cancelled() }}
run: |
# Send SIGINT to ragflow_server.py to trigger coverage save
PID=$(sudo docker exec ${RAGFLOW_CONTAINER} ps aux | grep "ragflow_server.py" | grep -v grep | awk '{print $2}' | head -n 1)
if [ -n "$PID" ]; then
echo "Sending SIGINT to ragflow_server.py (PID: $PID)..."
sudo docker exec ${RAGFLOW_CONTAINER} kill -INT $PID
# Wait for process to exit and coverage file to be written
sleep 10
else
echo "ragflow_server.py not found!"
fi
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} stop
- name: Generate server coverage report Infinity
if: ${{ !cancelled() }}
run: |
# .coverage file should be in docker/ragflow-logs/.coverage
if [ -f docker/ragflow-logs/.coverage ]; then
echo "Found .coverage file"
cp docker/ragflow-logs/.coverage .coverage
source .venv/bin/activate
# Create .coveragerc to map container paths to host paths
echo "[paths]" > .coveragerc
echo "source =" >> .coveragerc
echo " ." >> .coveragerc
echo " /ragflow" >> .coveragerc
coverage xml -o coverage-infinity-server.xml
rm .coveragerc
else
echo ".coverage file not found!"
fi
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
if: ${{ !cancelled() }}
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
- name: Collect ragflow log
if: ${{ !cancelled() }}
run: |
if [ -d docker/ragflow-logs ]; then
cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-infinity
echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log || true
else
echo "No docker/ragflow-logs directory found; skipping log collection"
fi
sudo rm -rf docker/ragflow-logs || true
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
- name: Stop ragflow:nightly
if: always() # always run this step even if previous steps failed

.gitignore vendored: 13 changes
View File

@ -44,7 +44,6 @@ cl100k_base.tiktoken
chrome*
huggingface.co/
nltk_data/
uv-x86_64*.tar.gz
# Exclude hash-like temporary files like 9b5ad71b2ce5302211f9c61530b329a4922fc6a4
*[0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f]*
@ -52,13 +51,6 @@ uv-x86_64*.tar.gz
.venv
docker/data
# OceanBase data and conf
docker/oceanbase/conf
docker/oceanbase/data
# SeekDB data and conf
docker/seekdb
#--------------------------------------------------#
# The following was generated with gitignore.nvim: #
@ -206,8 +198,3 @@ backup
.hypothesis
# Added by cargo
/target

View File

@ -27,7 +27,7 @@ RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on d
- **Document Processing**: `deepdoc/` - PDF parsing, OCR, layout analysis
- **LLM Integration**: `rag/llm/` - Model abstractions for chat, embedding, reranking
- **RAG Pipeline**: `rag/flow/` - Chunking, parsing, tokenization
- **Graph RAG**: `rag/graphrag/` - Knowledge graph construction and querying
- **Graph RAG**: `graphrag/` - Knowledge graph construction and querying
### Agent System (`/agent/`)
- **Components**: Modular workflow components (LLM, retrieval, categorize, etc.)

View File

@ -19,16 +19,17 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
cp -r /deps/nltk_data /root/ && \
cp /deps/tika-server-standard-3.2.3.jar /deps/tika-server-standard-3.2.3.jar.md5 /ragflow/ && \
cp /deps/tika-server-standard-3.0.0.jar /deps/tika-server-standard-3.0.0.jar.md5 /ragflow/ && \
cp /deps/cl100k_base.tiktoken /ragflow/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.2.3.jar"
ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.0.0.jar"
ENV DEBIAN_FRONTEND=noninteractive
# Setup apt
# Python package and implicit dependencies:
# opencv-python: libglib2.0-0 libglx-mesa0 libgl1
# python-pptx: default-jdk tika-server-standard-3.2.3.jar
# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
# python-pptx: default-jdk tika-server-standard-3.0.0.jar
# selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
@ -52,8 +53,7 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y ghostscript && \
apt install -y pandoc && \
apt install -y texlive && \
apt install -y fonts-freefont-ttf fonts-noto-cjk && \
apt install -y postgresql-client
apt install -y fonts-freefont-ttf fonts-noto-cjk
# Install uv
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
@ -64,12 +64,10 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps
echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
echo 'default = true' >> /etc/uv/uv.toml; \
fi; \
arch="$(uname -m)"; \
if [ "$arch" = "x86_64" ]; then uv_arch="x86_64"; else uv_arch="aarch64"; fi; \
tar xzf "/deps/uv-${uv_arch}-unknown-linux-gnu.tar.gz" \
&& cp "uv-${uv_arch}-unknown-linux-gnu/"* /usr/local/bin/ \
&& rm -rf "uv-${uv_arch}-unknown-linux-gnu" \
&& uv python install 3.12
tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \
&& cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \
&& rm -rf uv-x86_64-unknown-linux-gnu \
&& uv python install 3.11
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
ENV PATH=/root/.local/bin:$PATH
@ -127,6 +125,8 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/chromedriver-l
mv chromedriver /usr/local/bin/ && \
rm -f /usr/bin/google-chrome
# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13
# aspose-slides on linux/arm64 is unavailable
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
if [ "$(uname -m)" = "x86_64" ]; then \
dpkg -i /deps/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \
@ -152,14 +152,11 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
else \
sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
fi; \
uv sync --python 3.12 --frozen && \
# Ensure pip is available in the venv for runtime package installation (fixes #12651)
.venv/bin/python3 -m ensurepip --upgrade
uv sync --python 3.12 --frozen
COPY web web
COPY docs docs
RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked \
export NODE_OPTIONS="--max-old-space-size=4096" && \
cd web && npm install && npm run build
COPY .git /ragflow/.git
@ -189,8 +186,11 @@ COPY conf conf
COPY deepdoc deepdoc
COPY rag rag
COPY agent agent
COPY graphrag graphrag
COPY agentic_reasoning agentic_reasoning
COPY pyproject.toml uv.lock ./
COPY mcp mcp
COPY plugin plugin
COPY common common
COPY memory memory

View File

@ -3,7 +3,7 @@
FROM scratch
# Copy resources downloaded via download_deps.py
COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.2.3.jar tika-server-standard-3.2.3.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz uv-aarch64-unknown-linux-gnu.tar.gz /
COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz /
COPY nltk_data /nltk_data

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation (RAG) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged context engine and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
## 🎮 Demo
@ -85,7 +85,6 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
## 🔥 Latest Updates
- 2025-12-26 Supports 'Memory' for AI agent.
- 2025-11-19 Supports Gemini 3 Pro.
- 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, Google Drive.
- 2025-10-23 Supports MinerU & Docling as document parsing methods.
@ -188,12 +187,12 @@ releases! 🌟
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
> The command below downloads the `v0.23.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.23.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
> The command below downloads the `v0.22.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
```bash
$ cd ragflow/docker
# git checkout v0.23.1
# git checkout v0.22.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
@ -233,7 +232,7 @@ releases! 🌟
* Running on all addresses (0.0.0.0)
```
> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network abnormal`
> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network anormal`
> error because, at that moment, your RAGFlow may not be fully initialized.
>
5. In your web browser, enter the IP address of your server and log in to RAGFlow.
@ -303,15 +302,6 @@ cd ragflow/
docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
```
Or if you are behind a proxy, you can pass proxy arguments:
```bash
docker build --platform linux/amd64 \
--build-arg http_proxy=http://YOUR_PROXY:PORT \
--build-arg https_proxy=http://YOUR_PROXY:PORT \
-f Dockerfile -t infiniflow/ragflow:nightly .
```
## 🔨 Launch service from source for development
1. Install `uv` and `pre-commit`, or skip this step if they are already installed:
@ -396,7 +386,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)
See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
@ -37,7 +37,7 @@
<h4 align="center">
<a href="https://ragflow.io/docs/dev/">Dokumentasi</a> |
<a href="https://github.com/infiniflow/ragflow/issues/12241">Peta Jalan</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Peta Jalan</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 Apa Itu RAGFlow?
[RAGFlow](https://ragflow.io/) adalah mesin [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa.
[RAGFlow](https://ragflow.io/) adalah mesin RAG (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa.
## 🎮 Demo
@ -85,7 +85,6 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
## 🔥 Pembaruan Terbaru
- 2025-12-26 Mendukung 'Memori' untuk agen AI.
- 2025-11-19 Mendukung Gemini 3 Pro.
- 2025-11-12 Mendukung sinkronisasi data dari Confluence, S3, Notion, Discord, Google Drive.
- 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen.
@ -188,12 +187,12 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
> Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
> Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
> Perintah di bawah ini mengunduh edisi v0.23.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.23.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
> Perintah di bawah ini mengunduh edisi v0.22.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
```bash
$ cd ragflow/docker
# git checkout v0.23.1
# git checkout v0.22.1
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
@ -233,7 +232,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
* Running on all addresses (0.0.0.0)
```
> Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network abnormal`
> Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network anormal`
> karena RAGFlow mungkin belum sepenuhnya siap.
>
2. Buka browser web Anda, masukkan alamat IP server Anda, dan login ke RAGFlow.
@ -277,15 +276,6 @@ cd ragflow/
docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
```
Jika berada di belakang proxy, Anda dapat melewatkan argumen proxy:
```bash
docker build --platform linux/amd64 \
--build-arg http_proxy=http://YOUR_PROXY:PORT \
--build-arg https_proxy=http://YOUR_PROXY:PORT \
-f Dockerfile -t infiniflow/ragflow:nightly .
```
## 🔨 Menjalankan Aplikasi dari untuk Pengembangan
1. Instal `uv` dan `pre-commit`, atau lewati langkah ini jika sudah terinstal:
@ -368,7 +358,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
Lihat [Roadmap RAGFlow 2026](https://github.com/infiniflow/ragflow/issues/12241)
Lihat [Roadmap RAGFlow 2025](https://github.com/infiniflow/ragflow/issues/4214)
## 🏄 Komunitas

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a>
@ -53,7 +53,7 @@
## 💡 RAGFlow とは?
[RAGFlow](https://ragflow.io/) は、先進的な[RAG](https://ragflow.io/basics/what-is-rag)（Retrieval-Augmented Generation）技術と Agent 機能を融合し、大規模言語モデル（LLM）に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型[コンテキストエンジン](https://ragflow.io/basics/what-is-agent-context-engine)と事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。
[RAGFlow](https://ragflow.io/) は、先進的なRAG（Retrieval-Augmented Generation）技術と Agent 機能を融合し、大規模言語モデル（LLM）に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型コンテキストエンジンと事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。
## 🎮 Demo
@ -66,8 +66,7 @@
## 🔥 最新情報
- 2025-12-26 AIエージェントの「メモリ」機能をサポート。
- 2025-11-19 Gemini 3 Proをサポートしています。
- 2025-11-19 Gemini 3 Proをサポートしています
- 2025-11-12 Confluence、S3、Notion、Discord、Google Drive からのデータ同期をサポートします。
- 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。
- 2025-10-15 オーケストレーションされたデータパイプラインのサポート。
@ -168,12 +167,12 @@
> 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
> ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
> 以下のコマンドは、RAGFlow Docker イメージの v0.23.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.23.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
```bash
$ cd ragflow/docker
# git checkout v0.23.1
# git checkout v0.22.1
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases)
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
@ -277,15 +276,6 @@ cd ragflow/
docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
```
プロキシ環境下にいる場合は、プロキシ引数を指定できます:
```bash
docker build --platform linux/amd64 \
--build-arg http_proxy=http://YOUR_PROXY:PORT \
--build-arg https_proxy=http://YOUR_PROXY:PORT \
-f Dockerfile -t infiniflow/ragflow:nightly .
```
## 🔨 ソースコードからサービスを起動する方法
1. `uv` と `pre-commit` をインストールする。すでにインストールされている場合は、このステップをスキップしてください:
@ -368,7 +358,7 @@ docker build --platform linux/amd64 \
## 📜 ロードマップ
[RAGFlow ロードマップ 2026](https://github.com/infiniflow/ragflow/issues/12241) を参照
[RAGFlow ロードマップ 2025](https://github.com/infiniflow/ragflow/issues/4214) を参照
## 🏄 コミュニティ

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a>
@ -54,7 +54,7 @@
## 💡 RAGFlow란?
[RAGFlow](https://ragflow.io/) 는 최첨단 [RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 [컨텍스트 엔진](https://ragflow.io/basics/what-is-agent-context-engine)과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다.
[RAGFlow](https://ragflow.io/) 는 최첨단 RAG(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 컨텍스트 엔진과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다.
## 🎮 데모
@ -67,7 +67,6 @@
## 🔥 업데이트
- 2025-12-26 AI 에이전트의 '메모리' 기능 지원.
- 2025-11-19 Gemini 3 Pro를 지원합니다.
- 2025-11-12 Confluence, S3, Notion, Discord, Google Drive에서 데이터 동기화를 지원합니다.
- 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다.
@ -170,12 +169,12 @@
> 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
> ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
> 아래 명령어는 RAGFlow Docker 이미지의 v0.23.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.23.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
```bash
$ cd ragflow/docker
# git checkout v0.23.1
# git checkout v0.22.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
@ -214,7 +213,7 @@
* Running on all addresses (0.0.0.0)
```
> 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network abnormal` 오류가 발생할 수 있습니다.
> 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network anormal` 오류가 발생할 수 있습니다.
2. 웹 브라우저에 서버의 IP 주소를 입력하고 RAGFlow에 로그인하세요.
> 기본 설정을 사용할 경우, `http://IP_OF_YOUR_MACHINE`만 입력하면 됩니다 (포트 번호는 제외). 기본 HTTP 서비스 포트 `80`은 기본 구성으로 사용할 때 생략할 수 있습니다.
@ -271,15 +270,6 @@ cd ragflow/
docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
```
프록시 환경인 경우, 프록시 인수를 전달할 수 있습니다:
```bash
docker build --platform linux/amd64 \
--build-arg http_proxy=http://YOUR_PROXY:PORT \
--build-arg https_proxy=http://YOUR_PROXY:PORT \
-f Dockerfile -t infiniflow/ragflow:nightly .
```
## 🔨 소스 코드로 서비스를 시작합니다.
1. `uv` 와 `pre-commit` 을 설치하거나, 이미 설치된 경우 이 단계를 건너뜁니다:
@ -372,7 +362,7 @@ docker build --platform linux/amd64 \
## 📜 로드맵
[RAGFlow 로드맵 2026](https://github.com/infiniflow/ragflow/issues/12241)을 확인하세요.
[RAGFlow 로드맵 2025](https://github.com/infiniflow/ragflow/issues/4214)을 확인하세요.
## 🏄 커뮤니티

View File

@ -22,7 +22,7 @@
<img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
@ -37,7 +37,7 @@
<h4 align="center">
<a href="https://ragflow.io/docs/dev/">Documentação</a> |
<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a>
@ -73,7 +73,7 @@
## 💡 O que é o RAGFlow?
[RAGFlow](https://ragflow.io/) é um mecanismo de [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por [um motor de contexto](https://ragflow.io/basics/what-is-agent-context-engine) convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão.
[RAGFlow](https://ragflow.io/) é um mecanismo de RAG (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por um motor de contexto convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão.
## 🎮 Demo
@ -86,7 +86,6 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
## 🔥 Últimas Atualizações
- 26-12-2025 Suporte à função 'Memória' para agentes de IA.
- 19-11-2025 Suporta Gemini 3 Pro.
- 12-11-2025 Suporta a sincronização de dados do Confluence, S3, Notion, Discord e Google Drive.
- 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos.
@ -188,12 +187,12 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
> Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
> Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
> O comando abaixo baixa a edição`v0.23.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.23.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
> O comando abaixo baixa a edição`v0.22.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
```bash
$ cd ragflow/docker
# git checkout v0.23.1
# git checkout v0.22.1
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases)
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
@ -232,7 +231,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
* Rodando em todos os endereços (0.0.0.0)
```
> Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network abnormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado.
> Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network anormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado.
>
5. No seu navegador, insira o endereço IP do seu servidor e faça login no RAGFlow.
@ -294,15 +293,6 @@ cd ragflow/
docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
```
Se você estiver atrás de um proxy, pode passar argumentos de proxy:
```bash
docker build --platform linux/amd64 \
--build-arg http_proxy=http://YOUR_PROXY:PORT \
--build-arg https_proxy=http://YOUR_PROXY:PORT \
-f Dockerfile -t infiniflow/ragflow:nightly .
```
## 🔨 Lançar o serviço a partir do código-fonte para desenvolvimento
1. Instale o `uv` e o `pre-commit`, ou pule esta etapa se eles já estiverem instalados:
@ -385,7 +375,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
Veja o [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)
Veja o [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
## 🏄 Comunidade

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
[RAGFlow](https://ragflow.io/) is a leading open-source [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) engine that fuses cutting-edge RAG techniques with Agent capabilities to give large language models a superior context layer. It delivers an end-to-end RAG workflow that adapts to enterprises of any scale and, with its converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and prebuilt Agent templates, helps developers turn complex data into trustworthy, production-grade AI systems with maximum efficiency and precision.
[RAGFlow](https://ragflow.io/) is a leading open-source RAG (Retrieval-Augmented Generation) engine that fuses cutting-edge RAG techniques with Agent capabilities to give large language models a superior context layer. It delivers an end-to-end RAG workflow that adapts to enterprises of any scale and, with its converged context engine and prebuilt Agent templates, helps developers turn complex data into trustworthy, production-grade AI systems with maximum efficiency and precision.
## 🎮 Demo
@ -85,16 +85,15 @@
## 🔥 Latest Updates
- 2025-12-26 Supports "memory" for AI agents.
- 2025-11-19 Supports Gemini 3 Pro.
- 2025-11-19 Supports Gemini 3 Pro.
- 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, and Google Drive.
- 2025-10-23 Supports MinerU and Docling as document parsing methods.
- 2025-10-15 Supports orchestrable data pipelines.
- 2025-08-08 Supports OpenAI's latest GPT-5 series models.
- 2025-08-01 Supports agentic workflow and MCP.
- 2025-08-01 Supports agentic workflow and MCP
- 2025-05-23 Adds a Python/JS code executor component for Agent.
- 2025-05-05 Supports cross-language queries.
- 2025-03-19 Supports using multimodal LLMs to parse figures in PDF and DOCX into descriptions
- 2025-03-19 Supports using multimodal LLMs to parse figures in PDF and DOCX into descriptions.
- 2024-12-18 Upgrades the document layout analysis model in DeepDoc.
- 2024-08-22 Supports text-to-SQL conversion via RAG.
@ -125,7 +124,7 @@
### 🍔 **Compatible with heterogeneous data sources**
- Supports a wide range of file types, including Word documents, PPT, Excel spreadsheets, txt files, images, PDFs, scanned copies, printed copies, structured data, web pages, and more.
- Supports a wide range of file types, including Word documents, PPT, Excel spreadsheets, txt files, images, PDFs, scanned copies, printed copies, structured data, web pages, and more.
### 🛀 **Streamlined, automated RAG workflow**
@ -187,12 +186,12 @@
> All Docker images are built for x86 platforms. We do not currently offer Docker images for ARM64.
> If you are on an ARM64 platform, please use [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image that suits your system.
> Running the command below automatically downloads the RAGFlow Docker image `v0.23.1`. See the table below for descriptions of the different Docker editions. To download a Docker image other than `v0.23.1`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before running `docker compose` to start the service.
> Running the command below automatically downloads the RAGFlow Docker image `v0.22.1`. See the table below for descriptions of the different Docker editions. To download a Docker image other than `v0.22.1`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before running `docker compose` to start the service.
```bash
$ cd ragflow/docker
# git checkout v0.23.1
# git checkout v0.22.1
# Optional: use a stable release tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures that the entrypoint.sh file in the code matches the Docker image version.
@ -237,7 +236,7 @@
* Running on all addresses (0.0.0.0)
```
> If you log in to RAGFlow before this confirmation step completes, your browser may show a `network abnormal` or `網路異常` error because RAGFlow may not have fully started.
> If you log in to RAGFlow before this confirmation step completes, your browser may show a `network anormal` or `網路異常` error because RAGFlow may not have fully started.
>
5. In your browser, enter your server's IP address and log in to RAGFlow.
@ -303,15 +302,6 @@ cd ragflow/
docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
```
If you are in a proxy environment, you can pass proxy arguments:
```bash
docker build --platform linux/amd64 \
--build-arg http_proxy=http://YOUR_PROXY:PORT \
--build-arg https_proxy=http://YOUR_PROXY:PORT \
-f Dockerfile -t infiniflow/ragflow:nightly .
```
## 🔨 Launch the service from source
1. Install `uv` and `pre-commit`; skip this step if they are already installed:
@ -399,7 +389,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241).
See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214).
## 🏄 Community

View File

@ -22,7 +22,7 @@
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
</a>
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
</a>
<a href="https://github.com/infiniflow/ragflow/releases/latest">
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@ -37,7 +37,7 @@
<h4 align="center">
<a href="https://ragflow.io/docs/dev/">Document</a> |
<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
<a href="https://twitter.com/infiniflowai">Twitter</a> |
<a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
<a href="https://demo.ragflow.io">Demo</a>
@ -72,7 +72,7 @@
## 💡 What is RAGFlow?
[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that fuses cutting-edge RAG techniques with Agent capabilities to give large language models a superior context layer. It delivers an end-to-end RAG workflow that adapts to enterprises of any scale and, with its converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and prebuilt Agent templates, helps developers turn complex data into trustworthy, production-grade AI systems with maximum efficiency and precision.
[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation (RAG) engine that fuses cutting-edge RAG techniques with Agent capabilities to give large language models a superior context layer. It delivers an end-to-end RAG workflow that adapts to enterprises of any scale and, with its converged context engine and prebuilt Agent templates, helps developers turn complex data into trustworthy, production-grade AI systems with maximum efficiency and precision.
## 🎮 Demo
@ -85,8 +85,7 @@
## 🔥 Latest Updates
- 2025-12-26 Supports "memory" for AI agents.
- 2025-11-19 Supports Gemini 3 Pro.
- 2025-11-19 Supports Gemini 3 Pro.
- 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, and Google Drive.
- 2025-10-23 Supports MinerU and Docling as document parsing methods.
- 2025-10-15 Supports orchestrable data pipelines.
@ -94,7 +93,7 @@
- 2025-08-01 Supports agentic workflow and MCP.
- 2025-05-23 Adds a Python/JS code executor component for Agent.
- 2025-05-05 Supports cross-language queries.
- 2025-03-19 Supports using multimodal LLMs to parse figures in PDF and DOCX into descriptions
- 2025-03-19 Supports using multimodal LLMs to parse figures in PDF and DOCX into descriptions.
- 2024-12-18 Upgrades the document layout analysis model in DeepDoc.
- 2024-08-22 Supports text-to-SQL conversion via RAG.
@ -188,12 +187,12 @@
> Please note that all officially provided Docker images are currently built for the x86 architecture; we do not provide ARM64-based Docker images.
> If your operating system runs on ARM64, please refer to [this document](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image yourself.
> Running the command below automatically downloads the RAGFlow Docker image `v0.23.1`. See the table below for descriptions of the different Docker editions. To download a Docker image other than `v0.23.1`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before running `docker compose` to start the service.
> Running the command below automatically downloads the RAGFlow Docker image `v0.22.1`. See the table below for descriptions of the different Docker editions. To download a Docker image other than `v0.22.1`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before running `docker compose` to start the service.
```bash
$ cd ragflow/docker
# git checkout v0.23.1
# git checkout v0.22.1
# Optional: use a stable release tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This step ensures that the entrypoint.sh file in the code matches the Docker image version.
@ -238,7 +237,7 @@
* Running on all addresses (0.0.0.0)
```
> If you try to log in to RAGFlow before the confirmation message above appears, your browser may show a `network abnormal` or `网络异常` error.
> If you try to log in to RAGFlow before the confirmation message above appears, your browser may show a `network anormal` or `网络异常` error.
5. In your browser, enter your server's IP address and log in to RAGFlow.
> In the example above, you only need to enter http://IP_OF_YOUR_MACHINE; if the default configuration is unchanged, there is no need to specify a port (the default HTTP service port is 80).
@ -302,15 +301,6 @@ cd ragflow/
docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
```
If you are in a proxy environment, you can pass proxy arguments:
```bash
docker build --platform linux/amd64 \
--build-arg http_proxy=http://YOUR_PROXY:PORT \
--build-arg https_proxy=http://YOUR_PROXY:PORT \
-f Dockerfile -t infiniflow/ragflow:nightly .
```
## 🔨 Launch the service from source
1. Install `uv` and `pre-commit`; skip this step if they are already installed:
@ -402,7 +392,7 @@ docker build --platform linux/amd64 \
## 📜 Roadmap
See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241).
See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214).
## 🏄 Community

View File

@ -21,7 +21,7 @@ cp pyproject.toml release/$PROJECT_NAME/pyproject.toml
cp README.md release/$PROJECT_NAME/README.md
mkdir release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR -p
cp ragflow_cli.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/ragflow_cli.py
cp admin_client.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/admin_client.py
if [ -d "release/$PROJECT_NAME/$SOURCE_DIR" ]; then
echo "✅ source dir: release/$PROJECT_NAME/$SOURCE_DIR"

View File

@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
1. Ensure the Admin Service is running.
2. Install ragflow-cli.
```bash
pip install ragflow-cli==0.23.1
pip install ragflow-cli==0.22.1
```
3. Launch the CLI client:
```bash

View File

@ -0,0 +1,978 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import base64
from cmd import Cmd
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from typing import Dict, List, Any
from lark import Lark, Transformer, Tree
import requests
import getpass
GRAMMAR = r"""
start: command
command: sql_command | meta_command
sql_command: list_services
| show_service
| startup_service
| shutdown_service
| restart_service
| list_users
| show_user
| drop_user
| alter_user
| create_user
| activate_user
| list_datasets
| list_agents
| create_role
| drop_role
| alter_role
| list_roles
| show_role
| grant_permission
| revoke_permission
| alter_user_role
| show_user_permission
| show_version
// meta command definition
meta_command: "\\" meta_command_name [meta_args]
meta_command_name: /[a-zA-Z?]+/
meta_args: (meta_arg)+
meta_arg: /[^\\s"']+/ | quoted_string
// command definition
LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
RESTART: "RESTART"i
USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
PASSWORD: "PASSWORD"i
DATASETS: "DATASETS"i
OF: "OF"i
AGENTS: "AGENTS"i
ROLE: "ROLE"i
ROLES: "ROLES"i
DESCRIPTION: "DESCRIPTION"i
GRANT: "GRANT"i
REVOKE: "REVOKE"i
ALL: "ALL"i
PERMISSION: "PERMISSION"i
TO: "TO"i
FROM: "FROM"i
FOR: "FOR"i
RESOURCES: "RESOURCES"i
ON: "ON"i
SET: "SET"i
VERSION: "VERSION"i
list_services: LIST SERVICES ";"
show_service: SHOW SERVICE NUMBER ";"
startup_service: STARTUP SERVICE NUMBER ";"
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
restart_service: RESTART SERVICE NUMBER ";"
list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"
list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
drop_role: DROP ROLE identifier ";"
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
list_roles: LIST ROLES ";"
show_role: SHOW ROLE identifier ";"
grant_permission: GRANT action_list ON identifier TO ROLE identifier ";"
revoke_permission: REVOKE action_list ON identifier FROM ROLE identifier ";"
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
show_user_permission: SHOW USER PERMISSION quoted_string ";"
show_version: SHOW VERSION ";"
action_list: identifier ("," identifier)*
identifier: WORD
quoted_string: QUOTED_STRING
status: WORD
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
NUMBER: /[0-9]+/
%import common.WS
%ignore WS
"""
class AdminTransformer(Transformer):
def start(self, items):
return items[0]
def command(self, items):
return items[0]
def list_services(self, items):
result = {'type': 'list_services'}
return result
def show_service(self, items):
service_id = int(items[2])
return {"type": "show_service", "number": service_id}
def startup_service(self, items):
service_id = int(items[2])
return {"type": "startup_service", "number": service_id}
def shutdown_service(self, items):
service_id = int(items[2])
return {"type": "shutdown_service", "number": service_id}
def restart_service(self, items):
service_id = int(items[2])
return {"type": "restart_service", "number": service_id}
def list_users(self, items):
return {"type": "list_users"}
def show_user(self, items):
user_name = items[2]
return {"type": "show_user", "user_name": user_name}
def drop_user(self, items):
user_name = items[2]
return {"type": "drop_user", "user_name": user_name}
def alter_user(self, items):
user_name = items[3]
new_password = items[4]
return {"type": "alter_user", "user_name": user_name, "password": new_password}
def create_user(self, items):
user_name = items[2]
password = items[3]
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
def activate_user(self, items):
user_name = items[3]
activate_status = items[4]
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
def list_datasets(self, items):
user_name = items[3]
return {"type": "list_datasets", "user_name": user_name}
def list_agents(self, items):
user_name = items[3]
return {"type": "list_agents", "user_name": user_name}
def create_role(self, items):
role_name = items[2]
if len(items) > 4:
description = items[4]
return {"type": "create_role", "role_name": role_name, "description": description}
else:
return {"type": "create_role", "role_name": role_name}
def drop_role(self, items):
role_name = items[2]
return {"type": "drop_role", "role_name": role_name}
def alter_role(self, items):
role_name = items[2]
description = items[5]
return {"type": "alter_role", "role_name": role_name, "description": description}
def list_roles(self, items):
return {"type": "list_roles"}
def show_role(self, items):
role_name = items[2]
return {"type": "show_role", "role_name": role_name}
def grant_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def revoke_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {
"type": "revoke_permission",
"role_name": role_name,
"resource": resource, "actions": action_list
}
def alter_user_role(self, items):
user_name = items[2]
role_name = items[5]
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
def show_user_permission(self, items):
user_name = items[3]
return {"type": "show_user_permission", "user_name": user_name}
def show_version(self, items):
return {"type": "show_version"}
def action_list(self, items):
return items
def meta_command(self, items):
command_name = str(items[0]).lower()
args = items[1:] if len(items) > 1 else []
# handle quoted parameter
parsed_args = []
for arg in args:
if hasattr(arg, 'value'):
parsed_args.append(arg.value)
else:
parsed_args.append(str(arg))
return {'type': 'meta', 'command': command_name, 'args': parsed_args}
def meta_command_name(self, items):
return items[0]
def meta_args(self, items):
return items
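# Password transport helper: base64-encode the plaintext, RSA-encrypt it with the bundled
# public key (PKCS#1 v1.5), then base64-encode the ciphertext before sending it to the server.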
def encrypt(input_string):
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
pub_key = RSA.importKey(pub)
cipher = Cipher_pkcs1_v1_5.new(pub_key)
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
return base64.b64encode(cipher_text).decode("utf-8")
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
return base64_encoded.decode('utf-8')
class AdminCLI(Cmd):
def __init__(self):
super().__init__()
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
self.command_history = []
self.is_interactive = False
self.admin_account = "admin@ragflow.io"
self.admin_password: str = "admin"
self.session = requests.Session()
self.access_token: str = ""
self.host: str = ""
self.port: int = 0
intro = r"""Type "\h" for help."""
prompt = "admin> "
def onecmd(self, command: str) -> bool:
try:
result = self.parse_command(command)
if isinstance(result, dict):
if 'type' in result and result.get('type') == 'empty':
return False
self.execute_command(result)
if isinstance(result, Tree):
return False
if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
return True
except KeyboardInterrupt:
print("\nUse '\\q' to quit")
except EOFError:
print("\nGoodbye!")
return True
return False
def emptyline(self) -> bool:
return False
def default(self, line: str) -> bool:
return self.onecmd(line)
def parse_command(self, command_str: str) -> dict[str, str]:
if not command_str.strip():
return {'type': 'empty'}
self.command_history.append(command_str)
try:
result = self.parser.parse(command_str)
return result
except Exception as e:
return {'type': 'error', 'message': f'Parse error: {str(e)}'}
def verify_admin(self, arguments: dict, single_command: bool):
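# Log in against /api/v1/admin/login using the RSA-encrypted password. Interactive runs
# prompt for the password up to three times; a single-command run gets one attempt with the
# password supplied on the command line. On success the returned Authorization header is
# attached to the session for all subsequent requests.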
self.host = arguments['host']
self.port = arguments['port']
print("Attempt to access server for admin login")
url = f"http://{self.host}:{self.port}/api/v1/admin/login"
attempt_count = 3
if single_command:
attempt_count = 1
try_count = 0
while True:
try_count += 1
if try_count > attempt_count:
return False
if single_command:
admin_passwd = arguments['password']
else:
admin_passwd = getpass.getpass(f"password for {self.admin_account}: ").strip()
try:
self.admin_password = encrypt(admin_passwd)
response = self.session.post(url, json={'email': self.admin_account, 'password': self.admin_password})
if response.status_code == 200:
res_json = response.json()
error_code = res_json.get('code', -1)
if error_code == 0:
self.session.headers.update({
'Content-Type': 'application/json',
'Authorization': response.headers['Authorization'],
'User-Agent': 'RAGFlow-CLI/0.22.1'
})
print("Authentication successful.")
return True
else:
error_message = res_json.get('message', 'Unknown error')
print(f"Authentication failed: {error_message}, try again")
continue
else:
print(f"Bad response status: {response.status_code}, password is wrong")
except Exception as e:
print(str(e))
print("Can't access server for admin login (connection failed)")
def _format_service_detail_table(self, data):
if isinstance(data, list):
return data
if not all([isinstance(v, list) for v in data.values()]):
# normal table
return data
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]}
task_executor_list = []
for k, v in data.items():
# display latest status
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
task_executor_list.append({
"task_executor_name": k,
**heartbeats[0],
} if heartbeats else {"task_executor_name": k})
return task_executor_list
def _print_table_simple(self, data):
if not data:
print("No data to print")
return
if isinstance(data, dict):
# handle single row data
data = [data]
columns = list(set().union(*(d.keys() for d in data)))
columns.sort()
col_widths = {}
def get_string_width(text):
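# Approximate display width: ASCII characters count as one column, anything else
# (e.g. CJK) as two, so the table borders stay aligned.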
half_width_chars = (
" !\"#$%&'()*+,-./0123456789:;<=>?@"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`"
"abcdefghijklmnopqrstuvwxyz{|}~"
"\t\n\r"
)
width = 0
for char in text:
if char in half_width_chars:
width += 1
else:
width += 2
return width
for col in columns:
max_width = get_string_width(str(col))
for item in data:
value_len = get_string_width(str(item.get(col, '')))
if value_len > max_width:
max_width = value_len
col_widths[col] = max(2, max_width)
# Generate delimiter
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
# Print header
print(separator)
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
print(header)
print(separator)
# Print data
for item in data:
row = "|"
for col in columns:
value = str(item.get(col, ''))
if get_string_width(value) > col_widths[col]:
value = value[:col_widths[col] - 3] + "..."
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
print(row)
print(separator)
def run_interactive(self):
self.is_interactive = True
print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
while True:
try:
command = input("admin> ").strip()
if not command:
continue
print(f"command: {command}")
result = self.parse_command(command)
self.execute_command(result)
if isinstance(result, Tree):
continue
if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
break
except KeyboardInterrupt:
print("\nUse '\\q' to quit")
except EOFError:
print("\nGoodbye!")
break
def run_single_command(self, command: str):
result = self.parse_command(command)
self.execute_command(result)
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
parser = argparse.ArgumentParser(description='Admin CLI Client', add_help=False)
parser.add_argument('-h', '--host', default='localhost', help='Admin service host')
parser.add_argument('-p', '--port', type=int, default=9381, help='Admin service port')
parser.add_argument('-w', '--password', default='admin', type=str, help='Superuser password')
parser.add_argument('command', nargs='?', help='Single command')
try:
parsed_args, remaining_args = parser.parse_known_args(args)
if remaining_args:
command = remaining_args[0]
return {
'host': parsed_args.host,
'port': parsed_args.port,
'password': parsed_args.password,
'command': command
}
else:
return {
'host': parsed_args.host,
'port': parsed_args.port,
}
except SystemExit:
return {'error': 'Invalid connection arguments'}
def execute_command(self, parsed_command: Dict[str, Any]):
command_dict: dict
if isinstance(parsed_command, Tree):
command_dict = parsed_command.children[0]
else:
if parsed_command['type'] == 'error':
print(f"Error: {parsed_command['message']}")
return
else:
command_dict = parsed_command
# print(f"Parsed command: {command_dict}")
command_type = command_dict['type']
match command_type:
case 'list_services':
self._handle_list_services(command_dict)
case 'show_service':
self._handle_show_service(command_dict)
case 'restart_service':
self._handle_restart_service(command_dict)
case 'shutdown_service':
self._handle_shutdown_service(command_dict)
case 'startup_service':
self._handle_startup_service(command_dict)
case 'list_users':
self._handle_list_users(command_dict)
case 'show_user':
self._handle_show_user(command_dict)
case 'drop_user':
self._handle_drop_user(command_dict)
case 'alter_user':
self._handle_alter_user(command_dict)
case 'create_user':
self._handle_create_user(command_dict)
case 'activate_user':
self._handle_activate_user(command_dict)
case 'list_datasets':
self._handle_list_datasets(command_dict)
case 'list_agents':
self._handle_list_agents(command_dict)
case 'create_role':
self._create_role(command_dict)
case 'drop_role':
self._drop_role(command_dict)
case 'alter_role':
self._alter_role(command_dict)
case 'list_roles':
self._list_roles(command_dict)
case 'show_role':
self._show_role(command_dict)
case 'grant_permission':
self._grant_permission(command_dict)
case 'revoke_permission':
self._revoke_permission(command_dict)
case 'alter_user_role':
self._alter_user_role(command_dict)
case 'show_user_permission':
self._show_user_permission(command_dict)
case 'show_version':
self._show_version(command_dict)
case 'meta':
self._handle_meta_command(command_dict)
case _:
print(f"Command '{command_type}' would be executed with API")
def _handle_list_services(self, command):
print("Listing all services")
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all services, code: {res_json['code']}, message: {res_json['message']}")
def _handle_show_service(self, command):
service_id: int = command['number']
print(f"Showing service: {service_id}")
url = f'http://{self.host}:{self.port}/api/v1/admin/services/{service_id}'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
res_data = res_json['data']
if 'status' in res_data and res_data['status'] == 'alive':
print(f"Service {res_data['service_name']} is alive, ")
if isinstance(res_data['message'], str):
print(res_data['message'])
else:
data = self._format_service_detail_table(res_data['message'])
self._print_table_simple(data)
else:
print(f"Service {res_data['service_name']} is down, {res_data['message']}")
else:
print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}")
def _handle_restart_service(self, command):
service_id: int = command['number']
print(f"Restart service {service_id}")
def _handle_shutdown_service(self, command):
service_id: int = command['number']
print(f"Shutdown service {service_id}")
def _handle_startup_service(self, command):
service_id: int = command['number']
print(f"Startup service {service_id}")
def _handle_list_users(self, command):
print("Listing all users")
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
def _handle_show_user(self, command):
username_tree: Tree = command['user_name']
user_name: str = username_tree.children[0].strip("'\"")
print(f"Showing user: {user_name}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
table_data = res_json['data']
table_data.pop('avatar')
self._print_table_simple(table_data)
else:
print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_drop_user(self, command):
username_tree: Tree = command['user_name']
user_name: str = username_tree.children[0].strip("'\"")
print(f"Drop user: {user_name}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}'
response = self.session.delete(url)
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
def _handle_alter_user(self, command):
user_name_tree: Tree = command['user_name']
user_name: str = user_name_tree.children[0].strip("'\"")
password_tree: Tree = command['password']
password: str = password_tree.children[0].strip("'\"")
print(f"Alter user: {user_name}, password: ******")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password'
response = self.session.put(url, json={'new_password': encrypt(password)})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
def _handle_create_user(self, command):
user_name_tree: Tree = command['user_name']
user_name: str = user_name_tree.children[0].strip("'\"")
password_tree: Tree = command['password']
password: str = password_tree.children[0].strip("'\"")
role: str = command['role']
print(f"Create user: {user_name}, password: ******, role: {role}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
response = self.session.post(
url,
json={'user_name': user_name, 'password': encrypt(password), 'role': role}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to create user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_activate_user(self, command):
user_name_tree: Tree = command['user_name']
user_name: str = user_name_tree.children[0].strip("'\"")
activate_tree: Tree = command['activate_status']
activate_status: str = activate_tree.children[0].strip("'\"")
if activate_status.lower() in ['on', 'off']:
print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/activate'
response = self.session.put(url, json={'activate_status': activate_status})
res_json = response.json()
if response.status_code == 200:
print(res_json["message"])
else:
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
else:
print(f"Unknown activate status: {activate_status}.")
def _handle_list_datasets(self, command):
username_tree: Tree = command['user_name']
user_name: str = username_tree.children[0].strip("'\"")
print(f"Listing all datasets of user: {user_name}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/datasets'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
table_data = res_json['data']
for t in table_data:
t.pop('avatar')
self._print_table_simple(table_data)
else:
print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _handle_list_agents(self, command):
username_tree: Tree = command['user_name']
user_name: str = username_tree.children[0].strip("'\"")
print(f"Listing all agents of user: {user_name}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/agents'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
table_data = res_json['data']
for t in table_data:
t.pop('avatar')
self._print_table_simple(table_data)
else:
print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
def _create_role(self, command):
role_name_tree: Tree = command['role_name']
role_name: str = role_name_tree.children[0].strip("'\"")
desc_str: str = ''
if 'description' in command:
desc_tree: Tree = command['description']
desc_str = desc_tree.children[0].strip("'\"")
print(f"create role name: {role_name}, description: {desc_str}")
url = f'http://{self.host}:{self.port}/api/v1/admin/roles'
response = self.session.post(
url,
json={'role_name': role_name, 'description': desc_str}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to create role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
def _drop_role(self, command):
role_name_tree: Tree = command['role_name']
role_name: str = role_name_tree.children[0].strip("'\"")
print(f"drop role name: {role_name}")
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}'
response = self.session.delete(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to drop role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
def _alter_role(self, command):
role_name_tree: Tree = command['role_name']
role_name: str = role_name_tree.children[0].strip("'\"")
desc_tree: Tree = command['description']
desc_str: str = desc_tree.children[0].strip("'\"")
print(f"alter role name: {role_name}, description: {desc_str}")
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}'
response = self.session.put(
url,
json={'description': desc_str}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(
f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}")
def _list_roles(self, command):
print("Listing all roles")
url = f'http://{self.host}:{self.port}/api/v1/admin/roles'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
def _show_role(self, command):
role_name_tree: Tree = command['role_name']
role_name: str = role_name_tree.children[0].strip("'\"")
print(f"show role: {role_name}")
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}/permission'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
def _grant_permission(self, command):
role_name_tree: Tree = command['role_name']
role_name_str: str = role_name_tree.children[0].strip("'\"")
resource_tree: Tree = command['resource']
resource_str: str = resource_tree.children[0].strip("'\"")
action_tree_list: list = command['actions']
actions: list = []
for action_tree in action_tree_list:
action_str: str = action_tree.children[0].strip("'\"")
actions.append(action_str)
print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission'
response = self.session.post(
url,
json={'actions': actions, 'resource': resource_str}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(
f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
def _revoke_permission(self, command):
role_name_tree: Tree = command['role_name']
role_name_str: str = role_name_tree.children[0].strip("'\"")
resource_tree: Tree = command['resource']
resource_str: str = resource_tree.children[0].strip("'\"")
action_tree_list: list = command['actions']
actions: list = []
for action_tree in action_tree_list:
action_str: str = action_tree.children[0].strip("'\"")
actions.append(action_str)
print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission'
response = self.session.delete(
url,
json={'actions': actions, 'resource': resource_str}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(
f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
def _alter_user_role(self, command):
role_name_tree: Tree = command['role_name']
role_name_str: str = role_name_tree.children[0].strip("'\"")
user_name_tree: Tree = command['user_name']
user_name_str: str = user_name_tree.children[0].strip("'\"")
print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/role'
response = self.session.put(
url,
json={'role_name': role_name_str}
)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(
f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}")
def _show_user_permission(self, command):
user_name_tree: Tree = command['user_name']
user_name_str: str = user_name_tree.children[0].strip("'\"")
print(f"show_user_permission user_name: {user_name_str}")
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/permission'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(
f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}")
def _show_version(self, command):
print("show_version")
url = f'http://{self.host}:{self.port}/api/v1/admin/version'
response = self.session.get(url)
res_json = response.json()
if response.status_code == 200:
self._print_table_simple(res_json['data'])
else:
print(f"Fail to show version, code: {res_json['code']}, message: {res_json['message']}")
def _handle_meta_command(self, command):
meta_command = command['command']
args = command.get('args', [])
if meta_command in ['?', 'h', 'help']:
self.show_help()
elif meta_command in ['q', 'quit', 'exit']:
print("Goodbye!")
else:
print(f"Meta command '{meta_command}' with args {args}")
def show_help(self):
"""Help info"""
help_text = """
Commands:
LIST SERVICES
SHOW SERVICE <service>
STARTUP SERVICE <service>
SHUTDOWN SERVICE <service>
RESTART SERVICE <service>
LIST USERS
SHOW USER <user>
DROP USER <user>
CREATE USER <user> <password>
ALTER USER PASSWORD <user> <new_password>
ALTER USER ACTIVE <user> <on/off>
LIST DATASETS OF <user>
LIST AGENTS OF <user>
Meta Commands:
\\?, \\h, \\help Show this help
\\q, \\quit, \\exit Quit the CLI
"""
print(help_text)
def main():
import sys
cli = AdminCLI()
args = cli.parse_connection_args(sys.argv)
if 'error' in args:
print("Error: Invalid connection arguments")
return
if 'command' in args:
if 'password' not in args:
print("Error: password is missing")
return
if cli.verify_admin(args, single_command=True):
command: str = args['command']
# print(f"Run single command: {command}")
cli.run_single_command(command)
else:
if cli.verify_admin(args, single_command=False):
print(r"""
____ ___ ______________ ___ __ _
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
""")
cli.cmdloop()
if __name__ == '__main__':
main()

View File

@ -1,182 +0,0 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import json
import typing
from typing import Any, Dict, Optional
import requests
# from requests.sessions import HTTPAdapter
class HttpClient:
def __init__(
self,
host: str = "127.0.0.1",
port: int = 9381,
api_version: str = "v1",
api_key: Optional[str] = None,
connect_timeout: float = 5.0,
read_timeout: float = 60.0,
verify_ssl: bool = False,
) -> None:
self.host = host
self.port = port
self.api_version = api_version
self.api_key = api_key
self.login_token: str | None = None
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.verify_ssl = verify_ssl
def api_base(self) -> str:
return f"{self.host}:{self.port}/api/{self.api_version}"
def non_api_base(self) -> str:
return f"{self.host}:{self.port}/{self.api_version}"
def build_url(self, path: str, use_api_base: bool = True) -> str:
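# Note: verify_ssl doubles as the scheme switch here: https when True, plain http otherwise.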
base = self.api_base() if use_api_base else self.non_api_base()
if self.verify_ssl:
return f"https://{base}/{path.lstrip('/')}"
else:
return f"http://{base}/{path.lstrip('/')}"
def _headers(self, auth_kind: Optional[str], extra: Optional[Dict[str, str]]) -> Dict[str, str]:
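# Choose the Authorization header based on auth_kind: a Bearer API key for "api",
# the stored login token for "web" or "admin", otherwise no auth header.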
headers = {}
if auth_kind == "api" and self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
elif auth_kind == "web" and self.login_token:
headers["Authorization"] = self.login_token
elif auth_kind == "admin" and self.login_token:
headers["Authorization"] = self.login_token
else:
pass
if extra:
headers.update(extra)
return headers
def request(
self,
method: str,
path: str,
*,
use_api_base: bool = True,
auth_kind: Optional[str] = "api",
headers: Optional[Dict[str, str]] = None,
json_body: Optional[Dict[str, Any]] = None,
data: Any = None,
files: Any = None,
params: Optional[Dict[str, Any]] = None,
stream: bool = False,
iterations: int = 1,
) -> requests.Response | dict:
url = self.build_url(path, use_api_base=use_api_base)
merged_headers = self._headers(auth_kind, headers)
# timeout: Tuple[float, float] = (self.connect_timeout, self.read_timeout)
session = requests.Session()
# adapter = HTTPAdapter(pool_connections=100, pool_maxsize=100)
# session.mount("http://", adapter)
http_function = typing.Any
match method:
case "GET":
http_function = session.get
case "POST":
http_function = session.post
case "PUT":
http_function = session.put
case "DELETE":
http_function = session.delete
case "PATCH":
http_function = session.patch
case _:
raise ValueError(f"Invalid HTTP method: {method}")
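# Benchmark mode: with iterations > 1 the request is repeated and the total wall-clock
# duration plus every response object is returned; otherwise a single requests.Response
# is returned.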
if iterations > 1:
response_list = []
total_duration = 0.0
for _ in range(iterations):
start_time = time.perf_counter()
response = http_function(url, headers=merged_headers, json=json_body, data=data, stream=stream)
# response = session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
# response = requests.request(
# method=method,
# url=url,
# headers=merged_headers,
# json=json_body,
# data=data,
# files=files,
# params=params,
# stream=stream,
# verify=self.verify_ssl,
# )
end_time = time.perf_counter()
total_duration += end_time - start_time
response_list.append(response)
return {"duration": total_duration, "response_list": response_list}
else:
return http_function(url, headers=merged_headers, json=json_body, data=data, stream=stream)
# return session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
# return requests.request(
# method=method,
# url=url,
# headers=merged_headers,
# json=json_body,
# data=data,
# files=files,
# params=params,
# stream=stream,
# verify=self.verify_ssl,
# )
def request_json(
self,
method: str,
path: str,
*,
use_api_base: bool = True,
auth_kind: Optional[str] = "api",
headers: Optional[Dict[str, str]] = None,
json_body: Optional[Dict[str, Any]] = None,
data: Any = None,
files: Any = None,
params: Optional[Dict[str, Any]] = None,
stream: bool = False,
) -> Dict[str, Any]:
response = self.request(
method,
path,
use_api_base=use_api_base,
auth_kind=auth_kind,
headers=headers,
json_body=json_body,
data=data,
files=files,
params=params,
stream=stream,
)
try:
return response.json()
except Exception as exc:
raise ValueError(f"Non-JSON response from {path}: {exc}") from exc
@staticmethod
def parse_json_bytes(raw: bytes) -> Dict[str, Any]:
try:
return json.loads(raw.decode("utf-8"))
except Exception as exc:
raise ValueError(f"Invalid JSON payload: {exc}") from exc

View File

@ -1,623 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from lark import Transformer
GRAMMAR = r"""
start: command
command: sql_command | meta_command
sql_command: login_user
| ping_server
| list_services
| show_service
| startup_service
| shutdown_service
| restart_service
| register_user
| list_users
| show_user
| drop_user
| alter_user
| create_user
| activate_user
| list_datasets
| list_agents
| create_role
| drop_role
| alter_role
| list_roles
| show_role
| grant_permission
| revoke_permission
| alter_user_role
| show_user_permission
| show_version
| grant_admin
| revoke_admin
| set_variable
| show_variable
| list_variables
| list_configs
| list_environments
| generate_key
| list_keys
| drop_key
| show_current_user
| set_default_llm
| set_default_vlm
| set_default_embedding
| set_default_reranker
| set_default_asr
| set_default_tts
| reset_default_llm
| reset_default_vlm
| reset_default_embedding
| reset_default_reranker
| reset_default_asr
| reset_default_tts
| create_model_provider
| drop_model_provider
| create_user_dataset_with_parser
| create_user_dataset_with_pipeline
| drop_user_dataset
| list_user_datasets
| list_user_dataset_files
| list_user_agents
| list_user_chats
| create_user_chat
| drop_user_chat
| list_user_model_providers
| list_user_default_models
| parse_dataset_docs
| parse_dataset_sync
| parse_dataset_async
| import_docs_into_dataset
| search_on_datasets
| benchmark
// meta command definition
meta_command: "\\" meta_command_name [meta_args]
meta_command_name: /[a-zA-Z?]+/
meta_args: (meta_arg)+
meta_arg: /[^\\s"']+/ | quoted_string
// command definition
LOGIN: "LOGIN"i
REGISTER: "REGISTER"i
LIST: "LIST"i
SERVICES: "SERVICES"i
SHOW: "SHOW"i
CREATE: "CREATE"i
SERVICE: "SERVICE"i
SHUTDOWN: "SHUTDOWN"i
STARTUP: "STARTUP"i
RESTART: "RESTART"i
USERS: "USERS"i
DROP: "DROP"i
USER: "USER"i
ALTER: "ALTER"i
ACTIVE: "ACTIVE"i
ADMIN: "ADMIN"i
PASSWORD: "PASSWORD"i
DATASET: "DATASET"i
DATASETS: "DATASETS"i
OF: "OF"i
AGENTS: "AGENTS"i
ROLE: "ROLE"i
ROLES: "ROLES"i
DESCRIPTION: "DESCRIPTION"i
GRANT: "GRANT"i
REVOKE: "REVOKE"i
ALL: "ALL"i
PERMISSION: "PERMISSION"i
TO: "TO"i
FROM: "FROM"i
FOR: "FOR"i
RESOURCES: "RESOURCES"i
ON: "ON"i
SET: "SET"i
RESET: "RESET"i
VERSION: "VERSION"i
VAR: "VAR"i
VARS: "VARS"i
CONFIGS: "CONFIGS"i
ENVS: "ENVS"i
KEY: "KEY"i
KEYS: "KEYS"i
GENERATE: "GENERATE"i
MODEL: "MODEL"i
MODELS: "MODELS"i
PROVIDER: "PROVIDER"i
PROVIDERS: "PROVIDERS"i
DEFAULT: "DEFAULT"i
CHATS: "CHATS"i
CHAT: "CHAT"i
FILES: "FILES"i
AS: "AS"i
PARSE: "PARSE"i
IMPORT: "IMPORT"i
INTO: "INTO"i
WITH: "WITH"i
PARSER: "PARSER"i
PIPELINE: "PIPELINE"i
SEARCH: "SEARCH"i
CURRENT: "CURRENT"i
LLM: "LLM"i
VLM: "VLM"i
EMBEDDING: "EMBEDDING"i
RERANKER: "RERANKER"i
ASR: "ASR"i
TTS: "TTS"i
ASYNC: "ASYNC"i
SYNC: "SYNC"i
BENCHMARK: "BENCHMARK"i
PING: "PING"i
login_user: LOGIN USER quoted_string ";"
list_services: LIST SERVICES ";"
show_service: SHOW SERVICE NUMBER ";"
startup_service: STARTUP SERVICE NUMBER ";"
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
restart_service: RESTART SERVICE NUMBER ";"
register_user: REGISTER USER quoted_string AS quoted_string PASSWORD quoted_string ";"
list_users: LIST USERS ";"
drop_user: DROP USER quoted_string ";"
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
show_user: SHOW USER quoted_string ";"
create_user: CREATE USER quoted_string quoted_string ";"
activate_user: ALTER USER ACTIVE quoted_string status ";"
list_datasets: LIST DATASETS OF quoted_string ";"
list_agents: LIST AGENTS OF quoted_string ";"
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
drop_role: DROP ROLE identifier ";"
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
list_roles: LIST ROLES ";"
show_role: SHOW ROLE identifier ";"
grant_permission: GRANT identifier_list ON identifier TO ROLE identifier ";"
revoke_permission: REVOKE identifier_list ON identifier FROM ROLE identifier ";"
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
show_user_permission: SHOW USER PERMISSION quoted_string ";"
show_version: SHOW VERSION ";"
grant_admin: GRANT ADMIN quoted_string ";"
revoke_admin: REVOKE ADMIN quoted_string ";"
generate_key: GENERATE KEY FOR USER quoted_string ";"
list_keys: LIST KEYS OF quoted_string ";"
drop_key: DROP KEY quoted_string OF quoted_string ";"
set_variable: SET VAR identifier identifier ";"
show_variable: SHOW VAR identifier ";"
list_variables: LIST VARS ";"
list_configs: LIST CONFIGS ";"
list_environments: LIST ENVS ";"
benchmark: BENCHMARK NUMBER NUMBER user_statement
user_statement: ping_server
| show_current_user
| create_model_provider
| drop_model_provider
| set_default_llm
| set_default_vlm
| set_default_embedding
| set_default_reranker
| set_default_asr
| set_default_tts
| reset_default_llm
| reset_default_vlm
| reset_default_embedding
| reset_default_reranker
| reset_default_asr
| reset_default_tts
| create_user_dataset_with_parser
| create_user_dataset_with_pipeline
| drop_user_dataset
| list_user_datasets
| list_user_dataset_files
| list_user_agents
| list_user_chats
| create_user_chat
| drop_user_chat
| list_user_model_providers
| list_user_default_models
| import_docs_into_dataset
| search_on_datasets
ping_server: PING ";"
show_current_user: SHOW CURRENT USER ";"
create_model_provider: CREATE MODEL PROVIDER quoted_string quoted_string ";"
drop_model_provider: DROP MODEL PROVIDER quoted_string ";"
set_default_llm: SET DEFAULT LLM quoted_string ";"
set_default_vlm: SET DEFAULT VLM quoted_string ";"
set_default_embedding: SET DEFAULT EMBEDDING quoted_string ";"
set_default_reranker: SET DEFAULT RERANKER quoted_string ";"
set_default_asr: SET DEFAULT ASR quoted_string ";"
set_default_tts: SET DEFAULT TTS quoted_string ";"
reset_default_llm: RESET DEFAULT LLM ";"
reset_default_vlm: RESET DEFAULT VLM ";"
reset_default_embedding: RESET DEFAULT EMBEDDING ";"
reset_default_reranker: RESET DEFAULT RERANKER ";"
reset_default_asr: RESET DEFAULT ASR ";"
reset_default_tts: RESET DEFAULT TTS ";"
list_user_datasets: LIST DATASETS ";"
create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";"
create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";"
drop_user_dataset: DROP DATASET quoted_string ";"
list_user_dataset_files: LIST FILES OF DATASET quoted_string ";"
list_user_agents: LIST AGENTS ";"
list_user_chats: LIST CHATS ";"
create_user_chat: CREATE CHAT quoted_string ";"
drop_user_chat: DROP CHAT quoted_string ";"
list_user_model_providers: LIST MODEL PROVIDERS ";"
list_user_default_models: LIST DEFAULT MODELS ";"
import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";"
search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";"
parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
parse_dataset_async: PARSE DATASET quoted_string ASYNC ";"
identifier_list: identifier ("," identifier)*
identifier: WORD
quoted_string: QUOTED_STRING
status: WORD
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
WORD: /[a-zA-Z0-9_\-\.]+/
NUMBER: /[0-9]+/
%import common.WS
%ignore WS
"""
class RAGFlowCLITransformer(Transformer):
def start(self, items):
return items[0]
def command(self, items):
return items[0]
def login_user(self, items):
email = items[2].children[0].strip("'\"")
return {"type": "login_user", "email": email}
def ping_server(self, items):
return {"type": "ping_server"}
def list_services(self, items):
result = {"type": "list_services"}
return result
def show_service(self, items):
service_id = int(items[2])
return {"type": "show_service", "number": service_id}
def startup_service(self, items):
service_id = int(items[2])
return {"type": "startup_service", "number": service_id}
def shutdown_service(self, items):
service_id = int(items[2])
return {"type": "shutdown_service", "number": service_id}
def restart_service(self, items):
service_id = int(items[2])
return {"type": "restart_service", "number": service_id}
def register_user(self, items):
user_name: str = items[2].children[0].strip("'\"")
nickname: str = items[4].children[0].strip("'\"")
password: str = items[6].children[0].strip("'\"")
return {"type": "register_user", "user_name": user_name, "nickname": nickname, "password": password}
def list_users(self, items):
return {"type": "list_users"}
def show_user(self, items):
user_name = items[2]
return {"type": "show_user", "user_name": user_name}
def drop_user(self, items):
user_name = items[2]
return {"type": "drop_user", "user_name": user_name}
def alter_user(self, items):
user_name = items[3]
new_password = items[4]
return {"type": "alter_user", "user_name": user_name, "password": new_password}
def create_user(self, items):
user_name = items[2]
password = items[3]
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
def activate_user(self, items):
user_name = items[3]
activate_status = items[4]
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
def list_datasets(self, items):
user_name = items[3]
return {"type": "list_datasets", "user_name": user_name}
def list_agents(self, items):
user_name = items[3]
return {"type": "list_agents", "user_name": user_name}
def create_role(self, items):
role_name = items[2]
if len(items) > 4:
description = items[4]
return {"type": "create_role", "role_name": role_name, "description": description}
else:
return {"type": "create_role", "role_name": role_name}
def drop_role(self, items):
role_name = items[2]
return {"type": "drop_role", "role_name": role_name}
def alter_role(self, items):
role_name = items[2]
description = items[5]
return {"type": "alter_role", "role_name": role_name, "description": description}
def list_roles(self, items):
return {"type": "list_roles"}
def show_role(self, items):
role_name = items[2]
return {"type": "show_role", "role_name": role_name}
def grant_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def revoke_permission(self, items):
action_list = items[1]
resource = items[3]
role_name = items[6]
return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
def alter_user_role(self, items):
user_name = items[2]
role_name = items[5]
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
def show_user_permission(self, items):
user_name = items[3]
return {"type": "show_user_permission", "user_name": user_name}
def show_version(self, items):
return {"type": "show_version"}
def grant_admin(self, items):
user_name = items[2]
return {"type": "grant_admin", "user_name": user_name}
def revoke_admin(self, items):
user_name = items[2]
return {"type": "revoke_admin", "user_name": user_name}
def generate_key(self, items):
user_name = items[4]
return {"type": "generate_key", "user_name": user_name}
def list_keys(self, items):
user_name = items[3]
return {"type": "list_keys", "user_name": user_name}
def drop_key(self, items):
key = items[2]
user_name = items[4]
return {"type": "drop_key", "key": key, "user_name": user_name}
def set_variable(self, items):
var_name = items[2]
var_value = items[3]
return {"type": "set_variable", "var_name": var_name, "var_value": var_value}
def show_variable(self, items):
var_name = items[2]
return {"type": "show_variable", "var_name": var_name}
def list_variables(self, items):
return {"type": "list_variables"}
def list_configs(self, items):
return {"type": "list_configs"}
def list_environments(self, items):
return {"type": "list_environments"}
def create_model_provider(self, items):
provider_name = items[3].children[0].strip("'\"")
provider_key = items[4].children[0].strip("'\"")
return {"type": "create_model_provider", "provider_name": provider_name, "provider_key": provider_key}
def drop_model_provider(self, items):
provider_name = items[3].children[0].strip("'\"")
return {"type": "drop_model_provider", "provider_name": provider_name}
def show_current_user(self, items):
return {"type": "show_current_user"}
def set_default_llm(self, items):
llm_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "llm_id", "model_id": llm_id}
def set_default_vlm(self, items):
vlm_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "img2txt_id", "model_id": vlm_id}
def set_default_embedding(self, items):
embedding_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "embd_id", "model_id": embedding_id}
def set_default_reranker(self, items):
reranker_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "reranker_id", "model_id": reranker_id}
def set_default_asr(self, items):
asr_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "asr_id", "model_id": asr_id}
def set_default_tts(self, items):
tts_id = items[3].children[0].strip("'\"")
return {"type": "set_default_model", "model_type": "tts_id", "model_id": tts_id}
def reset_default_llm(self, items):
return {"type": "reset_default_model", "model_type": "llm_id"}
def reset_default_vlm(self, items):
return {"type": "reset_default_model", "model_type": "img2txt_id"}
def reset_default_embedding(self, items):
return {"type": "reset_default_model", "model_type": "embd_id"}
def reset_default_reranker(self, items):
return {"type": "reset_default_model", "model_type": "reranker_id"}
def reset_default_asr(self, items):
return {"type": "reset_default_model", "model_type": "asr_id"}
def reset_default_tts(self, items):
return {"type": "reset_default_model", "model_type": "tts_id"}
def list_user_datasets(self, items):
return {"type": "list_user_datasets"}
def create_user_dataset_with_parser(self, items):
dataset_name = items[2].children[0].strip("'\"")
embedding = items[5].children[0].strip("'\"")
parser_type = items[7].children[0].strip("'\"")
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
"parser_type": parser_type}
def create_user_dataset_with_pipeline(self, items):
dataset_name = items[2].children[0].strip("'\"")
embedding = items[5].children[0].strip("'\"")
pipeline = items[7].children[0].strip("'\"")
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
"pipeline": pipeline}
def drop_user_dataset(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "drop_user_dataset", "dataset_name": dataset_name}
def list_user_dataset_files(self, items):
dataset_name = items[4].children[0].strip("'\"")
return {"type": "list_user_dataset_files", "dataset_name": dataset_name}
def list_user_agents(self, items):
return {"type": "list_user_agents"}
def list_user_chats(self, items):
return {"type": "list_user_chats"}
def create_user_chat(self, items):
chat_name = items[2].children[0].strip("'\"")
return {"type": "create_user_chat", "chat_name": chat_name}
def drop_user_chat(self, items):
chat_name = items[2].children[0].strip("'\"")
return {"type": "drop_user_chat", "chat_name": chat_name}
def list_user_model_providers(self, items):
return {"type": "list_user_model_providers"}
def list_user_default_models(self, items):
return {"type": "list_user_default_models"}
def parse_dataset_docs(self, items):
document_list_str = items[1].children[0].strip("'\"")
document_names = document_list_str.split(",")
if len(document_names) == 1:
document_names = document_names[0]
document_names = document_names.split(" ")
dataset_name = items[4].children[0].strip("'\"")
return {"type": "parse_dataset_docs", "dataset_name": dataset_name, "document_names": document_names}
def parse_dataset_sync(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "sync"}
def parse_dataset_async(self, items):
dataset_name = items[2].children[0].strip("'\"")
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "async"}
def import_docs_into_dataset(self, items):
document_list_str = items[1].children[0].strip("'\"")
document_paths = document_list_str.split(",")
if len(document_paths) == 1:
document_paths = document_paths[0]
document_paths = document_paths.split(" ")
dataset_name = items[4].children[0].strip("'\"")
return {"type": "import_docs_into_dataset", "dataset_name": dataset_name, "document_paths": document_paths}
def search_on_datasets(self, items):
question = items[1].children[0].strip("'\"")
datasets_str = items[4].children[0].strip("'\"")
datasets = datasets_str.split(",")
if len(datasets) == 1:
datasets = datasets[0]
datasets = datasets.split(" ")
return {"type": "search_on_datasets", "datasets": datasets, "question": question}
def benchmark(self, items):
concurrency: int = int(items[1])
iterations: int = int(items[2])
command = items[3].children[0]
return {"type": "benchmark", "concurrency": concurrency, "iterations": iterations, "command": command}
def action_list(self, items):
return items
def meta_command(self, items):
command_name = str(items[0]).lower()
args = items[1:] if len(items) > 1 else []
# handle quoted parameter
parsed_args = []
for arg in args:
if hasattr(arg, "value"):
parsed_args.append(arg.value)
else:
parsed_args.append(str(arg))
return {"type": "meta", "command": command_name, "args": parsed_args}
def meta_command_name(self, items):
return items[0]
def meta_args(self, items):
return items
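A minimal sketch of how this transformer is wired into Lark, mirroring the RAGFlowCLI constructor further down; the sample command string is an assumption, since GRAMMAR itself is not shown in this diff:
from lark import Lark
from parser import GRAMMAR, RAGFlowCLITransformer
# Same construction as RAGFlowCLI.__init__ in admin_client.py below.
parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
result = parser.parse("show version;")  # "show version;" is an assumed example command
# With the transformer attached, each statement is reduced to a plain dict such as
# {"type": "show_version"}; execute_command() unwraps result.children[0] when a Tree
# is returned and hands the dict to run_command().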

View File

@ -1,6 +1,6 @@
[project]
name = "ragflow-cli"
version = "0.23.1"
version = "0.22.1"
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
license = { text = "Apache License, Version 2.0" }
@ -20,8 +20,5 @@ test = [
"requests-toolbelt>=1.0.0",
]
[tool.setuptools]
py-modules = ["ragflow_cli", "parser"]
[project.scripts]
ragflow-cli = "ragflow_cli:main"
ragflow-cli = "admin_client:main"

View File

@ -1,322 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import argparse
import base64
import getpass
from cmd import Cmd
from typing import Any, Dict, List
import requests
import warnings
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from Cryptodome.PublicKey import RSA
from lark import Lark, Tree
from parser import GRAMMAR, RAGFlowCLITransformer
from http_client import HttpClient
from ragflow_client import RAGFlowClient, run_command
from user import login_user
warnings.filterwarnings("ignore", category=getpass.GetPassWarning)
def encrypt(input_string):
pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
pub_key = RSA.importKey(pub)
cipher = Cipher_pkcs1_v1_5.new(pub_key)
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
return base64.b64encode(cipher_text).decode("utf-8")
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode("utf-8"))
return base64_encoded.decode("utf-8")
class RAGFlowCLI(Cmd):
def __init__(self):
super().__init__()
self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
self.command_history = []
self.account = "admin@ragflow.io"
self.account_password: str = "admin"
self.session = requests.Session()
self.host: str = ""
self.port: int = 0
self.mode: str = "admin"
self.ragflow_client = None
intro = r"""Type "\h" for help."""
prompt = "ragflow> "
def onecmd(self, command: str) -> bool:
try:
result = self.parse_command(command)
if isinstance(result, dict):
if "type" in result and result.get("type") == "empty":
return False
self.execute_command(result)
if isinstance(result, Tree):
return False
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
return True
except KeyboardInterrupt:
print("\nUse '\\q' to quit")
except EOFError:
print("\nGoodbye!")
return True
return False
def emptyline(self) -> bool:
return False
def default(self, line: str) -> bool:
return self.onecmd(line)
def parse_command(self, command_str: str) -> dict[str, str]:
if not command_str.strip():
return {"type": "empty"}
self.command_history.append(command_str)
try:
result = self.parser.parse(command_str)
return result
except Exception as e:
return {"type": "error", "message": f"Parse error: {str(e)}"}
def verify_auth(self, arguments: dict, single_command: bool, auth: bool):
server_type = arguments.get("type", "admin")
http_client = HttpClient(arguments["host"], arguments["port"])
if not auth:
self.ragflow_client = RAGFlowClient(http_client, server_type)
return True
user_name = arguments["username"]
attempt_count = 3
if single_command:
attempt_count = 1
try_count = 0
while True:
try_count += 1
if try_count > attempt_count:
return False
if single_command:
user_password = arguments["password"]
else:
user_password = getpass.getpass(f"password for {user_name}: ").strip()
try:
token = login_user(http_client, server_type, user_name, user_password)
http_client.login_token = token
self.ragflow_client = RAGFlowClient(http_client, server_type)
return True
except Exception as e:
print(str(e))
print("Can't access server for login (connection failed)")
def _format_service_detail_table(self, data):
if isinstance(data, list):
return data
if not all([isinstance(v, list) for v in data.values()]):
# normal table
return data
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]}
task_executor_list = []
for k, v in data.items():
# display latest status
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
task_executor_list.append(
{
"task_executor_name": k,
**heartbeats[0],
}
if heartbeats
else {"task_executor_name": k}
)
return task_executor_list
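# Illustrative input/output for the heartbeat flattening above (values are made up):
#   {"executor_0": [{"done": 2, "now": 1700000100}, {"done": 3, "now": 1700000200}]}
#   -> [{"task_executor_name": "executor_0", "done": 3, "now": 1700000200}]
# i.e. only the most recent heartbeat (largest "now") is kept for each executor.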
def _print_table_simple(self, data):
if not data:
print("No data to print")
return
if isinstance(data, dict):
# handle single row data
data = [data]
columns = list(set().union(*(d.keys() for d in data)))
columns.sort()
col_widths = {}
def get_string_width(text):
half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
width = 0
for char in text:
if char in half_width_chars:
width += 1
else:
width += 2
return width
for col in columns:
max_width = get_string_width(str(col))
for item in data:
value_len = get_string_width(str(item.get(col, "")))
if value_len > max_width:
max_width = value_len
col_widths[col] = max(2, max_width)
# Generate delimiter
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
# Print header
print(separator)
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
print(header)
print(separator)
# Print data
for item in data:
row = "|"
for col in columns:
value = str(item.get(col, ""))
if get_string_width(value) > col_widths[col]:
value = value[: col_widths[col] - 3] + "..."
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
print(row)
print(separator)
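# Illustrative output for _print_table_simple([{"name": "es", "status": "green"}]);
# full-width (e.g. CJK) characters count as two columns via get_string_width():
# +------+--------+
# | name | status |
# +------+--------+
# | es   | green  |
# +------+--------+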
def run_interactive(self, args):
if self.verify_auth(args, single_command=False, auth=args["auth"]):
print(r"""
____ ___ ______________ ________ ____
/ __ \/ | / ____/ ____/ /___ _ __ / ____/ / / _/
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / / / / / /
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / /___/ /____/ /
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ \____/_____/___/
""")
self.cmdloop()
print("RAGFlow command line interface - Type '\\?' for help, '\\q' to quit")
def run_single_command(self, args):
if self.verify_auth(args, single_command=True, auth=args["auth"]):
command = args["command"]
result = self.parse_command(command)
self.execute_command(result)
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
parser = argparse.ArgumentParser(description="RAGFlow CLI Client", add_help=False)
parser.add_argument("-h", "--host", default="127.0.0.1", help="Admin or RAGFlow service host")
parser.add_argument("-p", "--port", type=int, default=9381, help="Admin or RAGFlow service port")
parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
parser.add_argument("-t", "--type", default="admin", type=str, help="CLI mode, admin or user")
parser.add_argument("-u", "--username", default=None,
help="Username (email). In admin mode defaults to admin@ragflow.io, in user mode required.")
parser.add_argument("command", nargs="?", help="Single command")
try:
parsed_args, remaining_args = parser.parse_known_args(args)
# Determine username based on mode
username = parsed_args.username
if parsed_args.type == "admin":
if username is None:
username = "admin@ragflow.io"
if remaining_args:
if remaining_args[0] == "command":
command_str = ' '.join(remaining_args[1:]) + ';'
auth = True
if remaining_args[1] == "register":
auth = False
else:
if username is None:
print("Error: username (-u) is required in user mode")
return {"error": "Username required"}
return {
"host": parsed_args.host,
"port": parsed_args.port,
"password": parsed_args.password,
"type": parsed_args.type,
"username": username,
"command": command_str,
"auth": auth
}
else:
return {"error": "Invalid command"}
else:
auth = True
if username is None:
auth = False
return {
"host": parsed_args.host,
"port": parsed_args.port,
"type": parsed_args.type,
"username": username,
"auth": auth
}
except SystemExit:
return {"error": "Invalid connection arguments"}
def execute_command(self, parsed_command: Dict[str, Any]):
command_dict: dict
if isinstance(parsed_command, Tree):
command_dict = parsed_command.children[0]
else:
if parsed_command["type"] == "error":
print(f"Error: {parsed_command['message']}")
return
else:
command_dict = parsed_command
# print(f"Parsed command: {command_dict}")
run_command(self.ragflow_client, command_dict)
def main():
cli = RAGFlowCLI()
args = cli.parse_connection_args(sys.argv)
if "error" in args:
print("Error: Invalid connection arguments")
return
if "command" in args:
# single command mode
# for user mode, api key or password is ok
# for admin mode, only password
if "password" not in args:
print("Error: password is missing")
return
cli.run_single_command(args)
else:
cli.run_interactive(args)
if __name__ == "__main__":
main()

File diff suppressed because it is too large

View File

@ -1,65 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from http_client import HttpClient
class AuthException(Exception):
def __init__(self, message, code=401):
super().__init__(message)
self.code = code
self.message = message
def encrypt_password(password_plain: str) -> str:
try:
from api.utils.crypt import crypt
except Exception as exc:
raise AuthException(
"Password encryption unavailable; install pycryptodomex (uv sync --python 3.12 --group test)."
) from exc
return crypt(password_plain)
def register_user(client: HttpClient, email: str, nickname: str, password: str) -> None:
password_enc = encrypt_password(password)
payload = {"email": email, "nickname": nickname, "password": password_enc}
res = client.request_json("POST", "/user/register", use_api_base=False, auth_kind=None, json_body=payload)
if res.get("code") == 0:
return
msg = res.get("message", "")
if "has already registered" in msg:
return
raise AuthException(f"Register failed: {msg}")
def login_user(client: HttpClient, server_type: str, email: str, password: str) -> str:
password_enc = encrypt_password(password)
payload = {"email": email, "password": password_enc}
if server_type == "admin":
response = client.request("POST", "/admin/login", use_api_base=True, auth_kind=None, json_body=payload)
else:
response = client.request("POST", "/user/login", use_api_base=False, auth_kind=None, json_body=payload)
try:
res = response.json()
except Exception as exc:
raise AuthException(f"Login failed: invalid JSON response ({exc})") from exc
if res.get("code") != 0:
raise AuthException(f"Login failed: {res.get('message')}")
token = response.headers.get("Authorization")
if not token:
raise AuthException("Login failed: missing Authorization header")
return token
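A minimal usage sketch of these helpers, assuming a reachable server and mirroring how RAGFlowCLI.verify_auth() wires them together; host, port and credentials are placeholders:
from http_client import HttpClient
from user import register_user, login_user
client = HttpClient("127.0.0.1", 9381)  # placeholder host/port
register_user(client, "alice@example.com", "alice", "secret")  # returns silently if already registered
token = login_user(client, "user", "alice@example.com", "secret")
client.login_token = token  # as done in RAGFlowCLI.verify_auth()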

admin/client/uv.lock (generated, 2 lines changed)
View File

@ -196,7 +196,7 @@ wheels = [
[[package]]
name = "ragflow-cli"
version = "0.23.1"
version = "0.22.1"
source = { virtual = "." }
dependencies = [
{ name = "beartype" },

View File

@ -14,12 +14,10 @@
# limitations under the License.
#
import time
start_ts = time.time()
import os
import signal
import logging
import time
import threading
import traceback
import faulthandler
@ -68,7 +66,7 @@ if __name__ == '__main__':
SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)
try:
logging.info(f"RAGFlow admin is ready after {time.time() - start_ts}s initialization.")
logging.info("RAGFlow Admin service start...")
run_simple(
hostname="0.0.0.0",
port=9381,

View File

@ -27,8 +27,6 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from api.common.exceptions import AdminException, UserNotFoundError
from api.common.base64 import encode_to_base64
from api.db.services import UserService
from api.db import UserTenantRole
from api.db.services.user_service import TenantService, UserTenantService
from common.constants import ActiveEnum, StatusEnum
from api.utils.crypt import decrypt
from common.misc_utils import get_uuid
@ -87,44 +85,8 @@ def init_default_admin():
}
if not UserService.save(**default_admin):
raise AdminException("Can't init admin.", 500)
add_tenant_for_admin(default_admin, UserTenantRole.OWNER)
elif not any([u.is_active == ActiveEnum.ACTIVE.value for u in users]):
raise AdminException("No active admin. Please update 'is_active' in db manually.", 500)
else:
default_admin_rows = [u for u in users if u.email == "admin@ragflow.io"]
if default_admin_rows:
default_admin = default_admin_rows[0].to_dict()
exist, default_admin_tenant = TenantService.get_by_id(default_admin["id"])
if not exist:
add_tenant_for_admin(default_admin, UserTenantRole.OWNER)
def add_tenant_for_admin(user_info: dict, role: str):
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import get_init_tenant_llm
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
"user_id": user_info["id"],
"invited_by": user_info["id"],
"role": role
}
tenant_llm = get_init_tenant_llm(user_info["id"])
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
logging.info(
f"Added tenant for email: {user_info['email']}, A default tenant has been set; changing the default models after login is strongly recommended.")
def check_admin_auth(func):

View File

@ -15,34 +15,29 @@
#
import secrets
import logging
from typing import Any
from common.time_utils import current_timestamp, datetime_format
from datetime import datetime
from flask import Blueprint, Response, request
from flask import Blueprint, request
from flask_login import current_user, login_required, logout_user
from auth import login_verify, login_admin, check_admin_auth
from responses import success_response, error_response
from services import UserMgr, ServiceMgr, UserServiceMgr, SettingsMgr, ConfigMgr, EnvironmentsMgr, SandboxMgr
from services import UserMgr, ServiceMgr, UserServiceMgr
from roles import RoleMgr
from api.common.exceptions import AdminException
from common.versions import get_ragflow_version
from api.utils.api_utils import generate_confirmation_token
admin_bp = Blueprint("admin", __name__, url_prefix="/api/v1/admin")
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
@admin_bp.route("/ping", methods=["GET"])
@admin_bp.route('/ping', methods=['GET'])
def ping():
return success_response("PONG")
return success_response('PONG')
@admin_bp.route("/login", methods=["POST"])
@admin_bp.route('/login', methods=['POST'])
def login():
if not request.json:
return error_response("Authorize admin failed.", 400)
return error_response('Authorize admin failed.', 400)
try:
email = request.json.get("email", "")
password = request.json.get("password", "")
@ -51,7 +46,7 @@ def login():
return error_response(str(e), 500)
@admin_bp.route("/logout", methods=["GET"])
@admin_bp.route('/logout', methods=['GET'])
@login_required
def logout():
try:
@ -63,7 +58,7 @@ def logout():
return error_response(str(e), 500)
@admin_bp.route("/auth", methods=["GET"])
@admin_bp.route('/auth', methods=['GET'])
@login_verify
def auth_admin():
try:
@ -72,7 +67,7 @@ def auth_admin():
return error_response(str(e), 500)
@admin_bp.route("/users", methods=["GET"])
@admin_bp.route('/users', methods=['GET'])
@login_required
@check_admin_auth
def list_users():
@ -83,18 +78,18 @@ def list_users():
return error_response(str(e), 500)
@admin_bp.route("/users", methods=["POST"])
@admin_bp.route('/users', methods=['POST'])
@login_required
@check_admin_auth
def create_user():
try:
data = request.get_json()
if not data or "username" not in data or "password" not in data:
if not data or 'username' not in data or 'password' not in data:
return error_response("Username and password are required", 400)
username = data["username"]
password = data["password"]
role = data.get("role", "user")
username = data['username']
password = data['password']
role = data.get('role', 'user')
res = UserMgr.create_user(username, password, role)
if res["success"]:
@ -110,7 +105,7 @@ def create_user():
return error_response(str(e))
@admin_bp.route("/users/<username>", methods=["DELETE"])
@admin_bp.route('/users/<username>', methods=['DELETE'])
@login_required
@check_admin_auth
def delete_user(username):
@ -127,16 +122,16 @@ def delete_user(username):
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/password", methods=["PUT"])
@admin_bp.route('/users/<username>/password', methods=['PUT'])
@login_required
@check_admin_auth
def change_password(username):
try:
data = request.get_json()
if not data or "new_password" not in data:
if not data or 'new_password' not in data:
return error_response("New password is required", 400)
new_password = data["new_password"]
new_password = data['new_password']
msg = UserMgr.update_user_password(username, new_password)
return success_response(None, msg)
@ -146,15 +141,15 @@ def change_password(username):
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/activate", methods=["PUT"])
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
@login_required
@check_admin_auth
def alter_user_activate_status(username):
try:
data = request.get_json()
if not data or "activate_status" not in data:
if not data or 'activate_status' not in data:
return error_response("Activation status is required", 400)
activate_status = data["activate_status"]
activate_status = data['activate_status']
msg = UserMgr.update_user_activate_status(username, activate_status)
return success_response(None, msg)
except AdminException as e:
@ -163,39 +158,7 @@ def alter_user_activate_status(username):
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/admin", methods=["PUT"])
@login_required
@check_admin_auth
def grant_admin(username):
try:
if current_user.email == username:
return error_response(f"can't grant current user: {username}", 409)
msg = UserMgr.grant_admin(username)
return success_response(None, msg)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/admin", methods=["DELETE"])
@login_required
@check_admin_auth
def revoke_admin(username):
try:
if current_user.email == username:
return error_response(f"can't grant current user: {username}", 409)
msg = UserMgr.revoke_admin(username)
return success_response(None, msg)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>", methods=["GET"])
@admin_bp.route('/users/<username>', methods=['GET'])
@login_required
@check_admin_auth
def get_user_details(username):
@ -209,7 +172,7 @@ def get_user_details(username):
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/datasets", methods=["GET"])
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
@login_required
@check_admin_auth
def get_user_datasets(username):
@ -223,7 +186,7 @@ def get_user_datasets(username):
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/agents", methods=["GET"])
@admin_bp.route('/users/<username>/agents', methods=['GET'])
@login_required
@check_admin_auth
def get_user_agents(username):
@ -237,7 +200,7 @@ def get_user_agents(username):
return error_response(str(e), 500)
@admin_bp.route("/services", methods=["GET"])
@admin_bp.route('/services', methods=['GET'])
@login_required
@check_admin_auth
def get_services():
@ -248,7 +211,7 @@ def get_services():
return error_response(str(e), 500)
@admin_bp.route("/service_types/<service_type>", methods=["GET"])
@admin_bp.route('/service_types/<service_type>', methods=['GET'])
@login_required
@check_admin_auth
def get_services_by_type(service_type_str):
@ -259,7 +222,7 @@ def get_services_by_type(service_type_str):
return error_response(str(e), 500)
@admin_bp.route("/services/<service_id>", methods=["GET"])
@admin_bp.route('/services/<service_id>', methods=['GET'])
@login_required
@check_admin_auth
def get_service(service_id):
@ -270,7 +233,7 @@ def get_service(service_id):
return error_response(str(e), 500)
@admin_bp.route("/services/<service_id>", methods=["DELETE"])
@admin_bp.route('/services/<service_id>', methods=['DELETE'])
@login_required
@check_admin_auth
def shutdown_service(service_id):
@ -281,7 +244,7 @@ def shutdown_service(service_id):
return error_response(str(e), 500)
@admin_bp.route("/services/<service_id>", methods=["PUT"])
@admin_bp.route('/services/<service_id>', methods=['PUT'])
@login_required
@check_admin_auth
def restart_service(service_id):
@ -292,38 +255,38 @@ def restart_service(service_id):
return error_response(str(e), 500)
@admin_bp.route("/roles", methods=["POST"])
@admin_bp.route('/roles', methods=['POST'])
@login_required
@check_admin_auth
def create_role():
try:
data = request.get_json()
if not data or "role_name" not in data:
if not data or 'role_name' not in data:
return error_response("Role name is required", 400)
role_name: str = data["role_name"]
description: str = data["description"]
role_name: str = data['role_name']
description: str = data['description']
res = RoleMgr.create_role(role_name, description)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/roles/<role_name>", methods=["PUT"])
@admin_bp.route('/roles/<role_name>', methods=['PUT'])
@login_required
@check_admin_auth
def update_role(role_name: str):
try:
data = request.get_json()
if not data or "description" not in data:
if not data or 'description' not in data:
return error_response("Role description is required", 400)
description: str = data["description"]
description: str = data['description']
res = RoleMgr.update_role_description(role_name, description)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/roles/<role_name>", methods=["DELETE"])
@admin_bp.route('/roles/<role_name>', methods=['DELETE'])
@login_required
@check_admin_auth
def delete_role(role_name: str):
@ -334,7 +297,7 @@ def delete_role(role_name: str):
return error_response(str(e), 500)
@admin_bp.route("/roles", methods=["GET"])
@admin_bp.route('/roles', methods=['GET'])
@login_required
@check_admin_auth
def list_roles():
@ -345,7 +308,7 @@ def list_roles():
return error_response(str(e), 500)
@admin_bp.route("/roles/<role_name>/permission", methods=["GET"])
@admin_bp.route('/roles/<role_name>/permission', methods=['GET'])
@login_required
@check_admin_auth
def get_role_permission(role_name: str):
@ -356,54 +319,54 @@ def get_role_permission(role_name: str):
return error_response(str(e), 500)
@admin_bp.route("/roles/<role_name>/permission", methods=["POST"])
@admin_bp.route('/roles/<role_name>/permission', methods=['POST'])
@login_required
@check_admin_auth
def grant_role_permission(role_name: str):
try:
data = request.get_json()
if not data or "actions" not in data or "resource" not in data:
if not data or 'actions' not in data or 'resource' not in data:
return error_response("Permission is required", 400)
actions: list = data["actions"]
resource: str = data["resource"]
actions: list = data['actions']
resource: str = data['resource']
res = RoleMgr.grant_role_permission(role_name, actions, resource)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/roles/<role_name>/permission", methods=["DELETE"])
@admin_bp.route('/roles/<role_name>/permission', methods=['DELETE'])
@login_required
@check_admin_auth
def revoke_role_permission(role_name: str):
try:
data = request.get_json()
if not data or "actions" not in data or "resource" not in data:
if not data or 'actions' not in data or 'resource' not in data:
return error_response("Permission is required", 400)
actions: list = data["actions"]
resource: str = data["resource"]
actions: list = data['actions']
resource: str = data['resource']
res = RoleMgr.revoke_role_permission(role_name, actions, resource)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<user_name>/role", methods=["PUT"])
@admin_bp.route('/users/<user_name>/role', methods=['PUT'])
@login_required
@check_admin_auth
def update_user_role(user_name: str):
try:
data = request.get_json()
if not data or "role_name" not in data:
if not data or 'role_name' not in data:
return error_response("Role name is required", 400)
role_name: str = data["role_name"]
role_name: str = data['role_name']
res = RoleMgr.update_user_role(user_name, role_name)
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<user_name>/permission", methods=["GET"])
@admin_bp.route('/users/<user_name>/permission', methods=['GET'])
@login_required
@check_admin_auth
def get_user_permission(user_name: str):
@ -413,140 +376,7 @@ def get_user_permission(user_name: str):
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/variables", methods=["PUT"])
@login_required
@check_admin_auth
def set_variable():
try:
data = request.get_json()
if not data and "var_name" not in data:
return error_response("Var name is required", 400)
if "var_value" not in data:
return error_response("Var value is required", 400)
var_name: str = data["var_name"]
var_value: str = data["var_value"]
SettingsMgr.update_by_name(var_name, var_value)
return success_response(None, "Set variable successfully")
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/variables", methods=["GET"])
@login_required
@check_admin_auth
def get_variable():
try:
if request.content_length is None or request.content_length == 0:
# list variables
res = list(SettingsMgr.get_all())
return success_response(res)
# get var
data = request.get_json()
if not data and "var_name" not in data:
return error_response("Var name is required", 400)
var_name: str = data["var_name"]
res = SettingsMgr.get_by_name(var_name)
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/configs", methods=["GET"])
@login_required
@check_admin_auth
def get_config():
try:
res = list(ConfigMgr.get_all())
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/environments", methods=["GET"])
@login_required
@check_admin_auth
def get_environments():
try:
res = list(EnvironmentsMgr.get_all())
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/keys", methods=["POST"])
@login_required
@check_admin_auth
def generate_user_api_key(username: str) -> tuple[Response, int]:
try:
user_details: list[dict[str, Any]] = UserMgr.get_user_details(username)
if not user_details:
return error_response("User not found!", 404)
tenants: list[dict[str, Any]] = UserServiceMgr.get_user_tenants(username)
if not tenants:
return error_response("Tenant not found!", 404)
tenant_id: str = tenants[0]["tenant_id"]
key: str = generate_confirmation_token()
obj: dict[str, Any] = {
"tenant_id": tenant_id,
"token": key,
"beta": generate_confirmation_token().replace("ragflow-", "")[:32],
"create_time": current_timestamp(),
"create_date": datetime_format(datetime.now()),
"update_time": None,
"update_date": None,
}
if not UserMgr.save_api_key(obj):
return error_response("Failed to generate API key!", 500)
return success_response(obj, "API key generated successfully")
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/keys", methods=["GET"])
@login_required
@check_admin_auth
def get_user_api_keys(username: str) -> tuple[Response, int]:
try:
api_keys: list[dict[str, Any]] = UserMgr.get_user_api_key(username)
return success_response(api_keys, "Get user API keys")
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/users/<username>/keys/<key>", methods=["DELETE"])
@login_required
@check_admin_auth
def delete_user_api_key(username: str, key: str) -> tuple[Response, int]:
try:
deleted = UserMgr.delete_api_key(username, key)
if deleted:
return success_response(None, "API key deleted successfully")
else:
return error_response("API key not found or could not be deleted", 404)
except AdminException as e:
return error_response(e.message, e.code)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/version", methods=["GET"])
@admin_bp.route('/version', methods=['GET'])
@login_required
@check_admin_auth
def show_version():
@ -555,100 +385,3 @@ def show_version():
return success_response(res)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/sandbox/providers", methods=["GET"])
@login_required
@check_admin_auth
def list_sandbox_providers():
"""List all available sandbox providers."""
try:
res = SandboxMgr.list_providers()
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/sandbox/providers/<provider_id>/schema", methods=["GET"])
@login_required
@check_admin_auth
def get_sandbox_provider_schema(provider_id: str):
"""Get configuration schema for a specific provider."""
try:
res = SandboxMgr.get_provider_config_schema(provider_id)
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/sandbox/config", methods=["GET"])
@login_required
@check_admin_auth
def get_sandbox_config():
"""Get current sandbox configuration."""
try:
res = SandboxMgr.get_config()
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)
@admin_bp.route("/sandbox/config", methods=["POST"])
@login_required
@check_admin_auth
def set_sandbox_config():
"""Set sandbox provider configuration."""
try:
data = request.get_json()
if not data:
logging.error("set_sandbox_config: Request body is required")
return error_response("Request body is required", 400)
provider_type = data.get("provider_type")
if not provider_type:
logging.error("set_sandbox_config: provider_type is required")
return error_response("provider_type is required", 400)
config = data.get("config", {})
set_active = data.get("set_active", True) # Default to True for backward compatibility
logging.info(f"set_sandbox_config: provider_type={provider_type}, set_active={set_active}")
logging.info(f"set_sandbox_config: config keys={list(config.keys())}")
res = SandboxMgr.set_config(provider_type, config, set_active)
return success_response(res, "Sandbox configuration updated successfully")
except AdminException as e:
logging.exception("set_sandbox_config AdminException")
return error_response(str(e), 400)
except Exception as e:
logging.exception("set_sandbox_config unexpected error")
return error_response(str(e), 500)
@admin_bp.route("/sandbox/test", methods=["POST"])
@login_required
@check_admin_auth
def test_sandbox_connection():
"""Test connection to sandbox provider."""
try:
data = request.get_json()
if not data:
return error_response("Request body is required", 400)
provider_type = data.get("provider_type")
if not provider_type:
return error_response("provider_type is required", 400)
config = data.get("config", {})
res = SandboxMgr.test_connection(provider_type, config)
return success_response(res)
except AdminException as e:
return error_response(str(e), 400)
except Exception as e:
return error_response(str(e), 500)

View File

@ -13,23 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import logging
import re
from typing import Any
from werkzeug.security import check_password_hash
from common.constants import ActiveEnum
from api.db.services import UserService
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_service import TenantService, UserTenantService
from api.db.services.user_service import TenantService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.system_settings_service import SystemSettingsService
from api.db.services.api_service import APITokenService
from api.db.db_models import APIToken
from api.utils.crypt import decrypt
from api.utils import health_utils
@ -43,15 +35,13 @@ class UserMgr:
users = UserService.get_all_users()
result = []
for user in users:
result.append(
{
"email": user.email,
"nickname": user.nickname,
"create_date": user.create_date,
"is_active": user.is_active,
"is_superuser": user.is_superuser,
}
)
result.append({
'email': user.email,
'nickname': user.nickname,
'create_date': user.create_date,
'is_active': user.is_active,
'is_superuser': user.is_superuser,
})
return result
@staticmethod
@ -60,21 +50,19 @@ class UserMgr:
users = UserService.query_user_by_email(username)
result = []
for user in users:
result.append(
{
"avatar": user.avatar,
"email": user.email,
"language": user.language,
"last_login_time": user.last_login_time,
"is_active": user.is_active,
"is_anonymous": user.is_anonymous,
"login_channel": user.login_channel,
"status": user.status,
"is_superuser": user.is_superuser,
"create_date": user.create_date,
"update_date": user.update_date,
}
)
result.append({
'avatar': user.avatar,
'email': user.email,
'language': user.language,
'last_login_time': user.last_login_time,
'is_active': user.is_active,
'is_anonymous': user.is_anonymous,
'login_channel': user.login_channel,
'status': user.status,
'is_superuser': user.is_superuser,
'create_date': user.create_date,
'update_date': user.update_date
})
return result
@staticmethod
@ -136,8 +124,8 @@ class UserMgr:
# format activate_status before handle
_activate_status = activate_status.lower()
target_status = {
"on": ActiveEnum.ACTIVE.value,
"off": ActiveEnum.INACTIVE.value,
'on': ActiveEnum.ACTIVE.value,
'off': ActiveEnum.INACTIVE.value,
}.get(_activate_status)
if not target_status:
raise AdminException(f"Invalid activate_status: {activate_status}")
@ -147,84 +135,9 @@ class UserMgr:
UserService.update_user(usr.id, {"is_active": target_status})
return f"Turn {_activate_status} user activate status successfully!"
@staticmethod
def get_user_api_key(username: str) -> list[dict[str, Any]]:
# Look up the user by email; it must exist and be unique.
user_list: list[Any] = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"More than one user with username '{username}' found!")
usr: Any = user_list[0]
# tenant_id is typically the same as user_id for the owner tenant
tenant_id: str = usr.id
# Query all API keys for this tenant
api_keys: Any = APITokenService.query(tenant_id=tenant_id)
result: list[dict[str, Any]] = []
for key in api_keys:
result.append(key.to_dict())
return result
@staticmethod
def save_api_key(api_key: dict[str, Any]) -> bool:
return APITokenService.save(**api_key)
@staticmethod
def delete_api_key(username: str, key: str) -> bool:
# Look up the user by email; it must exist and be unique.
user_list: list[Any] = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"Exist more than 1 user: {username}!")
usr: Any = user_list[0]
# tenant_id is typically the same as user_id for the owner tenant
tenant_id: str = usr.id
# Delete the API key
deleted_count: int = APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == key])
return deleted_count > 0
@staticmethod
def grant_admin(username: str):
# Look up the user by email; it must exist and be unique.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"More than one user found: {username}!")
# skip if the user is already a superuser
usr = user_list[0]
if usr.is_superuser:
return f"{usr} is already a superuser!"
# update is_superuser
UserService.update_user(usr.id, {"is_superuser": True})
return "Granted successfully!"
@staticmethod
def revoke_admin(username: str):
# Look up the user by email; it must exist and be unique.
user_list = UserService.query_user_by_email(username)
if not user_list:
raise UserNotFoundError(username)
elif len(user_list) > 1:
raise AdminException(f"More than one user found: {username}!")
# skip if the user is not a superuser
usr = user_list[0]
if not usr.is_superuser:
return f"{usr} isn't a superuser yet!"
# update is_superuser
UserService.update_user(usr.id, {"is_superuser": False})
return "Revoked successfully!"
class UserServiceMgr:
@staticmethod
def get_user_datasets(username):
# use email to find user.
@ -254,43 +167,35 @@ class UserServiceMgr:
tenant_ids = [m["tenant_id"] for m in tenants]
# filter permitted agents and owned agents
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
return [{"title": r["title"], "permission": r["permission"], "canvas_category": r["canvas_category"].split("_")[0], "avatar": r["avatar"]} for r in res]
@staticmethod
def get_user_tenants(email: str) -> list[dict[str, Any]]:
users: list[Any] = UserService.query_user_by_email(email)
if not users:
raise UserNotFoundError(email)
user: Any = users[0]
tenants: list[dict[str, Any]] = UserTenantService.get_tenants_by_user_id(user.id)
return tenants
return [{
'title': r['title'],
'permission': r['permission'],
'canvas_category': r['canvas_category'].split('_')[0],
'avatar': r['avatar']
} for r in res]
class ServiceMgr:
@staticmethod
def get_all_services():
doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
result = []
configs = SERVICE_CONFIGS.configs
for service_id, config in enumerate(configs):
config_dict = config.to_dict()
if config_dict["service_type"] == "retrieval":
if config_dict["extra"]["retrieval_type"] != doc_engine:
continue
try:
service_detail = ServiceMgr.get_service_details(service_id)
if "status" in service_detail:
config_dict["status"] = service_detail["status"]
config_dict['status'] = service_detail['status']
else:
config_dict["status"] = "timeout"
config_dict['status'] = 'timeout'
except Exception as e:
logging.warning(f"Can't get service details, error: {e}")
config_dict["status"] = "timeout"
if not config_dict["host"]:
config_dict["host"] = "-"
if not config_dict["port"]:
config_dict["port"] = "-"
config_dict['status'] = 'timeout'
if not config_dict['host']:
config_dict['host'] = '-'
if not config_dict['port']:
config_dict['port'] = '-'
result.append(config_dict)
return result
@ -306,18 +211,11 @@ class ServiceMgr:
raise AdminException(f"invalid service_index: {service_idx}")
service_config = configs[service_idx]
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
# exclude retrieval service if retrieval_type is not matched
doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
if service_config.service_type == "retrieval":
if service_config.retrieval_type != doc_engine:
raise AdminException(f"invalid service_index: {service_idx}")
service_info = {"name": service_config.name, "detail_func_name": service_config.detail_func_name}
detail_func = getattr(health_utils, service_info.get("detail_func_name"))
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
res = detail_func()
res.update({"service_name": service_info.get("name")})
res.update({'service_name': service_info.get('name')})
return res
@staticmethod
@ -327,397 +225,3 @@ class ServiceMgr:
@staticmethod
def restart_service(service_id: int):
raise AdminException("restart_service: not implemented")
class SettingsMgr:
@staticmethod
def get_all():
settings = SystemSettingsService.get_all()
result = []
for setting in settings:
result.append(
{
"name": setting.name,
"source": setting.source,
"data_type": setting.data_type,
"value": setting.value,
}
)
return result
@staticmethod
def get_by_name(name: str):
settings = SystemSettingsService.get_by_name(name)
if len(settings) == 0:
raise AdminException(f"Can't get setting: {name}")
result = []
for setting in settings:
result.append(
{
"name": setting.name,
"source": setting.source,
"data_type": setting.data_type,
"value": setting.value,
}
)
return result
@staticmethod
def update_by_name(name: str, value: str):
settings = SystemSettingsService.get_by_name(name)
if len(settings) == 1:
setting = settings[0]
setting.value = value
setting_dict = setting.to_dict()
SystemSettingsService.update_by_name(name, setting_dict)
elif len(settings) > 1:
raise AdminException(f"Can't update more than 1 setting: {name}")
else:
# Create new setting if it doesn't exist
# Determine data_type based on name and value
if name.startswith("sandbox."):
data_type = "json"
elif name.endswith(".enabled"):
data_type = "boolean"
else:
data_type = "string"
new_setting = {
"name": name,
"value": str(value),
"source": "admin",
"data_type": data_type,
}
SystemSettingsService.save(**new_setting)
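# Illustrative data_type inference for newly created settings (the names are hypothetical):
#   "sandbox.self_managed" -> "json"      (name starts with "sandbox.")
#   "smtp.enabled"         -> "boolean"   (name ends with ".enabled")
#   "smtp.host"            -> "string"    (everything else)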
class ConfigMgr:
@staticmethod
def get_all():
result = []
configs = SERVICE_CONFIGS.configs
for config in configs:
config_dict = config.to_dict()
result.append(config_dict)
return result
class EnvironmentsMgr:
@staticmethod
def get_all():
result = []
env_kv = {"env": "DOC_ENGINE", "value": os.getenv("DOC_ENGINE")}
result.append(env_kv)
env_kv = {"env": "DEFAULT_SUPERUSER_EMAIL", "value": os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")}
result.append(env_kv)
env_kv = {"env": "DB_TYPE", "value": os.getenv("DB_TYPE", "mysql")}
result.append(env_kv)
env_kv = {"env": "DEVICE", "value": os.getenv("DEVICE", "cpu")}
result.append(env_kv)
env_kv = {"env": "STORAGE_IMPL", "value": os.getenv("STORAGE_IMPL", "MINIO")}
result.append(env_kv)
return result
class SandboxMgr:
"""Manager for sandbox provider configuration and operations."""
# Provider registry with metadata
PROVIDER_REGISTRY = {
"self_managed": {
"name": "Self-Managed",
"description": "On-premise deployment using Daytona/Docker",
"tags": ["self-hosted", "low-latency", "secure"],
},
"aliyun_codeinterpreter": {
"name": "Aliyun Code Interpreter",
"description": "Aliyun Function Compute Code Interpreter - Code execution in serverless microVMs",
"tags": ["saas", "cloud", "scalable", "aliyun"],
},
"e2b": {
"name": "E2B",
"description": "E2B Cloud - Code Execution Sandboxes",
"tags": ["saas", "fast", "global"],
},
}
@staticmethod
def list_providers():
"""List all available sandbox providers."""
result = []
for provider_id, metadata in SandboxMgr.PROVIDER_REGISTRY.items():
result.append({
"id": provider_id,
**metadata
})
return result
@staticmethod
def get_provider_config_schema(provider_id: str):
"""Get configuration schema for a specific provider."""
from agent.sandbox.providers import (
SelfManagedProvider,
AliyunCodeInterpreterProvider,
E2BProvider,
)
schemas = {
"self_managed": SelfManagedProvider.get_config_schema(),
"aliyun_codeinterpreter": AliyunCodeInterpreterProvider.get_config_schema(),
"e2b": E2BProvider.get_config_schema(),
}
if provider_id not in schemas:
raise AdminException(f"Unknown provider: {provider_id}")
return schemas.get(provider_id, {})
@staticmethod
def get_config():
"""Get current sandbox configuration."""
try:
# Get active provider type
provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type")
if not provider_type_settings:
# Return default config if not set
provider_type = "self_managed"
else:
provider_type = provider_type_settings[0].value
# Get provider-specific config
provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}")
if not provider_config_settings:
provider_config = {}
else:
try:
provider_config = json.loads(provider_config_settings[0].value)
except json.JSONDecodeError:
provider_config = {}
return {
"provider_type": provider_type,
"config": provider_config,
}
except Exception as e:
raise AdminException(f"Failed to get sandbox config: {str(e)}")
@staticmethod
def set_config(provider_type: str, config: dict, set_active: bool = True):
"""
Set sandbox provider configuration.
Args:
provider_type: Provider identifier (e.g., "self_managed", "e2b")
config: Provider configuration dictionary
set_active: If True, also update the active provider. If False,
only update the configuration without switching providers.
Default: True
Returns:
Dictionary with updated provider_type and config
"""
from agent.sandbox.providers import (
SelfManagedProvider,
AliyunCodeInterpreterProvider,
E2BProvider,
)
try:
# Validate provider type
if provider_type not in SandboxMgr.PROVIDER_REGISTRY:
raise AdminException(f"Unknown provider type: {provider_type}")
# Get provider schema for validation
schema = SandboxMgr.get_provider_config_schema(provider_type)
# Validate config against schema
for field_name, field_schema in schema.items():
if field_schema.get("required", False) and field_name not in config:
raise AdminException(f"Required field '{field_name}' is missing")
# Type validation
if field_name in config:
field_type = field_schema.get("type")
if field_type == "integer":
if not isinstance(config[field_name], int):
raise AdminException(f"Field '{field_name}' must be an integer")
elif field_type == "string":
if not isinstance(config[field_name], str):
raise AdminException(f"Field '{field_name}' must be a string")
elif field_type == "bool":
if not isinstance(config[field_name], bool):
raise AdminException(f"Field '{field_name}' must be a boolean")
# Range validation for integers
if field_type == "integer" and field_name in config:
min_val = field_schema.get("min")
max_val = field_schema.get("max")
if min_val is not None and config[field_name] < min_val:
raise AdminException(f"Field '{field_name}' must be >= {min_val}")
if max_val is not None and config[field_name] > max_val:
raise AdminException(f"Field '{field_name}' must be <= {max_val}")
# Provider-specific custom validation
provider_classes = {
"self_managed": SelfManagedProvider,
"aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
"e2b": E2BProvider,
}
provider = provider_classes[provider_type]()
is_valid, error_msg = provider.validate_config(config)
if not is_valid:
raise AdminException(f"Provider validation failed: {error_msg}")
# Update provider_type only if set_active is True
if set_active:
SettingsMgr.update_by_name("sandbox.provider_type", provider_type)
# Always update the provider config
config_json = json.dumps(config)
SettingsMgr.update_by_name(f"sandbox.{provider_type}", config_json)
return {"provider_type": provider_type, "config": config}
except AdminException:
raise
except Exception as e:
raise AdminException(f"Failed to set sandbox config: {str(e)}")
@staticmethod
def test_connection(provider_type: str, config: dict):
"""
Test connection to sandbox provider by executing a simple Python script.
This creates a temporary sandbox instance and runs a test code to verify:
- Connection credentials are valid
- Sandbox can be created
- Code execution works correctly
Args:
provider_type: Provider identifier
config: Provider configuration dictionary
Returns:
dict with test results including stdout, stderr, exit_code, execution_time
"""
try:
from agent.sandbox.providers import (
SelfManagedProvider,
AliyunCodeInterpreterProvider,
E2BProvider,
)
# Instantiate provider based on type
provider_classes = {
"self_managed": SelfManagedProvider,
"aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
"e2b": E2BProvider,
}
if provider_type not in provider_classes:
raise AdminException(f"Unknown provider type: {provider_type}")
provider = provider_classes[provider_type]()
# Initialize with config
if not provider.initialize(config):
raise AdminException(f"Failed to initialize provider '{provider_type}'")
# Create a temporary sandbox instance for testing
instance = provider.create_instance(template="python")
if not instance or instance.status != "READY":
raise AdminException(f"Failed to create sandbox instance. Status: {instance.status if instance else 'None'}")
# Simple test code that exercises basic Python functionality
test_code = """
# Test basic Python functionality
import sys
import json
import math
print("Python version:", sys.version)
print("Platform:", sys.platform)
# Test basic calculations
result = 2 + 2
print(f"2 + 2 = {result}")
# Test JSON operations
data = {"test": "data", "value": 123}
print(f"JSON dump: {json.dumps(data)}")
# Test math operations
print(f"Math.sqrt(16) = {math.sqrt(16)}")
# Test error handling
try:
x = 1 / 1
print("Division test: OK")
except Exception as e:
print(f"Error: {e}")
# Return success indicator
print("TEST_PASSED")
"""
# Execute test code with timeout
execution_result = provider.execute_code(
instance_id=instance.instance_id,
code=test_code,
language="python",
timeout=10 # 10 seconds timeout
)
# Clean up the test instance (if provider supports it)
try:
if hasattr(provider, 'terminate_instance'):
provider.terminate_instance(instance.instance_id)
logging.info(f"Cleaned up test instance {instance.instance_id}")
else:
logging.warning(f"Provider {provider_type} does not support terminate_instance, test instance may leak")
except Exception as cleanup_error:
logging.warning(f"Failed to cleanup test instance {instance.instance_id}: {cleanup_error}")
# Build detailed result message
success = execution_result.exit_code == 0 and "TEST_PASSED" in execution_result.stdout
message_parts = [
f"Test {success and 'PASSED' or 'FAILED'}",
f"Exit code: {execution_result.exit_code}",
f"Execution time: {execution_result.execution_time:.2f}s"
]
if execution_result.stdout.strip():
stdout_preview = execution_result.stdout.strip()[:200]
message_parts.append(f"Output: {stdout_preview}...")
if execution_result.stderr.strip():
stderr_preview = execution_result.stderr.strip()[:200]
message_parts.append(f"Errors: {stderr_preview}...")
message = " | ".join(message_parts)
return {
"success": success,
"message": message,
"details": {
"exit_code": execution_result.exit_code,
"execution_time": execution_result.execution_time,
"stdout": execution_result.stdout,
"stderr": execution_result.stderr,
}
}
except AdminException:
raise
except Exception as e:
import traceback
error_details = traceback.format_exc()
raise AdminException(f"Connection test failed: {str(e)}\\n\\nStack trace:\\n{error_details}")

View File

@ -78,14 +78,13 @@ class Graph:
}
"""
def __init__(self, dsl: str, tenant_id=None, task_id=None, custom_header=None):
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.path = []
self.components = {}
self.error = ""
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid()
self.custom_header = custom_header
self._thread_pool = ThreadPoolExecutor(max_workers=5)
self.load()
@ -95,7 +94,6 @@ class Graph:
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
cpn["obj"]["params"]["custom_header"] = self.custom_header
param.update(cpn["obj"]["params"])
try:
param.check()
@ -280,32 +278,27 @@ class Graph:
class Canvas(Graph):
def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custom_header=None):
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": [],
"sys.history": []
"sys.files": []
}
self.variables = {}
super().__init__(dsl, tenant_id, task_id, custom_header=custom_header)
self._id = canvas_id
super().__init__(dsl, tenant_id, task_id)
def load(self):
super().load()
self.history = self.dsl["history"]
if "globals" in self.dsl:
self.globals = self.dsl["globals"]
if "sys.history" not in self.globals:
self.globals["sys.history"] = []
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": [],
"sys.history": []
"sys.files": []
}
if "variables" in self.dsl:
self.variables = self.dsl["variables"]
@ -346,23 +339,21 @@ class Canvas(Graph):
key = k[4:]
if key in self.variables:
variable = self.variables[key]
if variable["type"] == "string":
self.globals[k] = ""
variable["value"] = ""
elif variable["type"] == "number":
self.globals[k] = 0
variable["value"] = 0
elif variable["type"] == "boolean":
self.globals[k] = False
variable["value"] = False
elif variable["type"] == "object":
self.globals[k] = {}
variable["value"] = {}
elif variable["type"].startswith("array"):
self.globals[k] = []
variable["value"] = []
if variable["value"]:
self.globals[k] = variable["value"]
else:
self.globals[k] = ""
if variable["type"] == "string":
self.globals[k] = ""
elif variable["type"] == "number":
self.globals[k] = 0
elif variable["type"] == "boolean":
self.globals[k] = False
elif variable["type"] == "object":
self.globals[k] = {}
elif variable["type"].startswith("array"):
self.globals[k] = []
else:
self.globals[k] = ""
else:
self.globals[k] = ""
@ -427,15 +418,9 @@ class Canvas(Graph):
loop = asyncio.get_running_loop()
tasks = []
max_concurrency = getattr(self._thread_pool, "_max_workers", 5)
sem = asyncio.Semaphore(max_concurrency)
async def _invoke_one(cpn_obj, sync_fn, call_kwargs, use_async: bool):
async with sem:
if use_async:
await cpn_obj.invoke_async(**(call_kwargs or {}))
return
await loop.run_in_executor(self._thread_pool, partial(sync_fn, **(call_kwargs or {})))
def _run_async_in_thread(coro_func, **call_kwargs):
return asyncio.run(coro_func(**call_kwargs))
i = f
while i < t:
@ -461,9 +446,11 @@ class Canvas(Graph):
if task_fn is None:
continue
fn_invoke_async = getattr(cpn, "_invoke_async", None)
use_async = (fn_invoke_async and asyncio.iscoroutinefunction(fn_invoke_async)) or asyncio.iscoroutinefunction(getattr(cpn, "_invoke", None))
tasks.append(asyncio.create_task(_invoke_one(cpn, task_fn, call_kwargs, use_async)))
invoke_async = getattr(cpn, "invoke_async", None)
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
else:
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
if tasks:
await asyncio.gather(*tasks)
@ -650,7 +637,6 @@ class Canvas(Graph):
"created_at": st,
})
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}")
elif "Task has been canceled" in self.error:
yield decorate("workflow_finished",
{
@ -728,7 +714,6 @@ class Canvas(Graph):
def add_user_input(self, question):
self.history.append(("user", question))
self.globals["sys.history"].append(f"{self.history[-1][0]}: {self.history[-1][1]}")
def get_prologue(self):
return self.components["begin"]["obj"]._param.prologue
@ -736,9 +721,6 @@ class Canvas(Graph):
def get_mode(self):
return self.components["begin"]["obj"]._param.mode
def get_sys_query(self):
return self.globals.get("sys.query", "")
def set_global_param(self, **kwargs):
self.globals.update(kwargs)
@ -754,16 +736,13 @@ class Canvas(Graph):
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
def parse_file(file):
blob = FileService.get_blob(file["created_by"], file["id"])
return FileService.parse(file["name"], blob, True, file["created_by"])
loop = asyncio.get_running_loop()
tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file))
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return await asyncio.gather(*tasks)
def get_files(self, files: Union[None, list[dict]]) -> list[str]:

View File

@ -76,8 +76,6 @@ class AgentParam(LLMParam, ToolParamBase):
self.mcp = []
self.max_rounds = 5
self.description = ""
self.custom_header = {}
class Agent(LLM, ToolBase):
@ -107,8 +105,7 @@ class Agent(LLM, ToolBase):
for mcp in self._param.mcp:
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
custom_header = self._param.custom_header
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header)
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
for tnm, meta in mcp["tools"].items():
self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
self.tools[tnm] = tool_call_session

View File

@ -27,10 +27,6 @@ import pandas as pd
from agent import settings
from common.connection_utils import timeout
from common.misc_utils import thread_pool_exec
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
_USER_FEEDED_PARAMS = "_user_feeded_params"
@ -383,7 +379,6 @@ class ComponentBase(ABC):
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Graph # Local import to avoid cyclic dependency
assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
@ -435,7 +430,7 @@ class ComponentBase(ABC):
elif asyncio.iscoroutinefunction(self._invoke):
await self._invoke(**kwargs)
else:
await thread_pool_exec(self._invoke, **kwargs)
await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()

View File

@ -45,14 +45,11 @@ class Begin(UserFillUp):
if self.check_if_canceled("Begin processing"):
return
if isinstance(v, dict) and v.get("type", "").lower().find("file") >= 0:
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
if v.get("optional") and v.get("value", None) is None:
v = None
else:
file_value = v["value"]
# Support both single file (backward compatibility) and multiple files
files = file_value if isinstance(file_value, list) else [file_value]
v = FileService.get_files(files)
v = FileService.get_files([v["value"]])
else:
v = v.get("value")
self.set_output(k, v)

View File

@ -97,13 +97,6 @@ Here's description of each category:
class Categorize(LLM, ABC):
component_name = "Categorize"
def get_input_elements(self) -> dict[str, dict]:
query_key = self._param.query or "sys.query"
elements = self.get_input_elements_from_text(f"{{{query_key}}}")
if not elements:
logging.warning(f"[Categorize] input element not detected for query key: {query_key}")
return elements
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Categorize processing"):
@ -112,15 +105,12 @@ class Categorize(LLM, ABC):
msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg:
msg = [{"role": "user", "content": ""}]
query_key = self._param.query or "sys.query"
if query_key in kwargs:
query_value = kwargs[query_key]
if kwargs.get("sys.query"):
msg[-1]["content"] = kwargs["sys.query"]
self.set_input_value("sys.query", kwargs["sys.query"])
else:
query_value = self._canvas.get_variable_value(query_key)
if query_value is None:
query_value = ""
msg[-1]["content"] = query_value
self.set_input_value(query_key, msg[-1]["content"])
msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
self.set_input_value(self._param.query, msg[-1]["content"])
self._param.update_prompt()
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
@ -147,7 +137,7 @@ class Categorize(LLM, ABC):
category_counts[c] = count
cpn_ids = list(self._param.category_description.items())[-1][1]["to"]
max_category = list(self._param.category_description.keys())[-1]
max_category = list(self._param.category_description.keys())[0]
if any(category_counts.values()):
max_category = max(category_counts.items(), key=lambda x: x[1])[0]
cpn_ids = self._param.category_description[max_category]["to"]

View File

@ -64,14 +64,11 @@ class UserFillUp(ComponentBase):
for k, v in kwargs.get("inputs", {}).items():
if self.check_if_canceled("UserFillUp processing"):
return
if isinstance(v, dict) and v.get("type", "").lower().find("file") >= 0:
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
if v.get("optional") and v.get("value", None) is None:
v = None
else:
file_value = v["value"]
# Support both single file (backward compatibility) and multiple files
files = file_value if isinstance(file_value, list) else [file_value]
v = FileService.get_files(files)
v = FileService.get_files([v["value"]])
else:
v = v.get("value")
self.set_output(k, v)

View File

@ -57,10 +57,12 @@ class Iteration(ComponentBase, ABC):
return cid
def _invoke(self, **kwargs):
if not self.check_if_canceled("Iteration processing"):
arr = self._canvas.get_variable_value(self._param.items_ref)
if not isinstance(arr, list):
self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr)))
if self.check_if_canceled("Iteration processing"):
return
arr = self._canvas.get_variable_value(self._param.items_ref)
if not isinstance(arr, list):
self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr)))
def thoughts(self) -> str:
return "Need to process {} items.".format(len(self._canvas.get_variable_value(self._param.items_ref)))

View File

@ -51,27 +51,29 @@ class Loop(ComponentBase, ABC):
return cid
def _invoke(self, **kwargs):
if not self.check_if_canceled("Loop processing"):
for item in self._param.loop_variables:
if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]):
assert "Loop Variable is not complete."
if item["input_mode"]=="variable":
self.set_output(item["variable"],self._canvas.get_variable_value(item["value"]))
elif item["input_mode"]=="constant":
self.set_output(item["variable"],item["value"])
if self.check_if_canceled("Loop processing"):
return
for item in self._param.loop_variables:
if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]):
assert "Loop Variable is not complete."
if item["input_mode"]=="variable":
self.set_output(item["variable"],self._canvas.get_variable_value(item["value"]))
elif item["input_mode"]=="constant":
self.set_output(item["variable"],item["value"])
else:
if item["type"] == "number":
self.set_output(item["variable"], 0)
elif item["type"] == "string":
self.set_output(item["variable"], "")
elif item["type"] == "boolean":
self.set_output(item["variable"], False)
elif item["type"].startswith("object"):
self.set_output(item["variable"], {})
elif item["type"].startswith("array"):
self.set_output(item["variable"], [])
else:
if item["type"] == "number":
self.set_output(item["variable"], 0)
elif item["type"] == "string":
self.set_output(item["variable"], "")
elif item["type"] == "boolean":
self.set_output(item["variable"], False)
elif item["type"].startswith("object"):
self.set_output(item["variable"], {})
elif item["type"].startswith("array"):
self.set_output(item["variable"], [])
else:
self.set_output(item["variable"], "")
self.set_output(item["variable"], "")
def thoughts(self) -> str:

View File

@ -33,8 +33,6 @@ from common.connection_utils import timeout
from common.misc_utils import get_uuid
from common import settings
from api.db.joint_services.memory_message_service import queue_save_to_memory_task
class MessageParam(ComponentParamBase):
"""
@ -168,7 +166,6 @@ class Message(ComponentBase):
self.set_output("content", all_content)
self._convert_content(all_content)
await self._save_to_memory(all_content)
def _is_jinjia2(self, content:str) -> bool:
patt = [
@ -201,7 +198,6 @@ class Message(ComponentBase):
self.set_output("content", content)
self._convert_content(content)
self._save_to_memory(content)
def thoughts(self) -> str:
return ""
@ -425,16 +421,3 @@ class Message(ComponentBase):
except Exception as e:
logging.error(f"Error converting content to {self._param.output_format}: {e}")
async def _save_to_memory(self, content):
if not hasattr(self._param, "memory_ids") or not self._param.memory_ids:
return True, "No memory selected."
message_dict = {
"user_id": self._canvas._tenant_id,
"agent_id": self._canvas._id,
"session_id": self._canvas.task_id,
"user_input": self._canvas.get_sys_query(),
"agent_response": content
}
return await queue_save_to_memory_task(self._param.memory_ids, message_dict)

View File

@ -1,239 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Sandbox client for agent components.
This module provides a unified interface for agent components to interact
with the configured sandbox provider.
"""
import json
import logging
from typing import Dict, Any, Optional
from api.db.services.system_settings_service import SystemSettingsService
from agent.sandbox.providers import ProviderManager
from agent.sandbox.providers.base import ExecutionResult
logger = logging.getLogger(__name__)
# Global provider manager instance
_provider_manager: Optional[ProviderManager] = None
def get_provider_manager() -> ProviderManager:
"""
Get the global provider manager instance.
Returns:
ProviderManager instance with active provider loaded
"""
global _provider_manager
if _provider_manager is not None:
return _provider_manager
# Initialize provider manager with system settings
_provider_manager = ProviderManager()
_load_provider_from_settings()
return _provider_manager
def _load_provider_from_settings() -> None:
"""
Load sandbox provider from system settings and configure the provider manager.
This function reads the system settings to determine which provider is active
and initializes it with the appropriate configuration.
"""
global _provider_manager
if _provider_manager is None:
return
try:
# Get active provider type
provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type")
if not provider_type_settings:
raise RuntimeError(
"Sandbox provider type not configured. Please set 'sandbox.provider_type' in system settings."
)
provider_type = provider_type_settings[0].value
# Get provider configuration
provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}")
if not provider_config_settings:
logger.warning(f"No configuration found for provider: {provider_type}")
config = {}
else:
try:
config = json.loads(provider_config_settings[0].value)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse sandbox config for {provider_type}: {e}")
config = {}
# Import and instantiate the provider
from agent.sandbox.providers import (
SelfManagedProvider,
AliyunCodeInterpreterProvider,
E2BProvider,
)
provider_classes = {
"self_managed": SelfManagedProvider,
"aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
"e2b": E2BProvider,
}
if provider_type not in provider_classes:
logger.error(f"Unknown provider type: {provider_type}")
return
provider_class = provider_classes[provider_type]
provider = provider_class()
# Initialize the provider
if not provider.initialize(config):
logger.error(f"Failed to initialize sandbox provider: {provider_type}. Config keys: {list(config.keys())}")
return
# Set the active provider
_provider_manager.set_provider(provider_type, provider)
logger.info(f"Sandbox provider '{provider_type}' initialized successfully")
except Exception as e:
logger.error(f"Failed to load sandbox provider from settings: {e}")
import traceback
traceback.print_exc()
def reload_provider() -> None:
"""
Reload the sandbox provider from system settings.
Use this function when sandbox settings have been updated.
"""
global _provider_manager
_provider_manager = None
_load_provider_from_settings()
def execute_code(
code: str,
language: str = "python",
timeout: int = 30,
arguments: Optional[Dict[str, Any]] = None
) -> ExecutionResult:
"""
Execute code in the configured sandbox.
This is the main entry point for agent components to execute code.
Args:
code: Source code to execute
language: Programming language (python, nodejs, javascript)
timeout: Maximum execution time in seconds
arguments: Optional arguments dict to pass to main() function
Returns:
ExecutionResult containing stdout, stderr, exit_code, and metadata
Raises:
RuntimeError: If no provider is configured or execution fails
"""
provider_manager = get_provider_manager()
if not provider_manager.is_configured():
raise RuntimeError(
"No sandbox provider configured. Please configure sandbox settings in the admin panel."
)
provider = provider_manager.get_provider()
# Create a sandbox instance
instance = provider.create_instance(template=language)
try:
# Execute the code
result = provider.execute_code(
instance_id=instance.instance_id,
code=code,
language=language,
timeout=timeout,
arguments=arguments
)
return result
finally:
# Clean up the instance
try:
provider.destroy_instance(instance.instance_id)
except Exception as e:
logger.warning(f"Failed to destroy sandbox instance {instance.instance_id}: {e}")
def health_check() -> bool:
"""
Check if the sandbox provider is healthy.
Returns:
True if provider is configured and healthy, False otherwise
"""
try:
provider_manager = get_provider_manager()
if not provider_manager.is_configured():
return False
provider = provider_manager.get_provider()
return provider.health_check()
except Exception as e:
logger.error(f"Sandbox health check failed: {e}")
return False
def get_provider_info() -> Dict[str, Any]:
"""
Get information about the current sandbox provider.
Returns:
Dictionary with provider information:
- provider_type: Type of the active provider
- configured: Whether provider is configured
- healthy: Whether provider is healthy
"""
try:
provider_manager = get_provider_manager()
return {
"provider_type": provider_manager.get_provider_name(),
"configured": provider_manager.is_configured(),
"healthy": health_check(),
}
except Exception as e:
logger.error(f"Failed to get provider info: {e}")
return {
"provider_type": None,
"configured": False,
"healthy": False,
}
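The client module above (now removed) exposed `execute_code()` and `health_check()` as the single entry points for agent components. As a hedged illustration only, here is a minimal sketch of how a caller might have used that API, assuming the module path `agent.sandbox.client` and the `ExecutionResult` fields shown above; the helper name `run_user_snippet` is hypothetical.

```python
# Minimal sketch of calling the removed sandbox client API.
# Assumes the module path agent.sandbox.client and the ExecutionResult
# dataclass fields shown above; run_user_snippet is a hypothetical helper.
from agent.sandbox.client import execute_code, health_check

def run_user_snippet(code: str) -> str:
    # Guard on provider health before running anything.
    if not health_check():
        raise RuntimeError("Sandbox provider is not configured or unhealthy")
    result = execute_code(code=code, language="python", timeout=15)
    if result.exit_code != 0:
        raise RuntimeError(f"Sandbox execution failed: {result.stderr}")
    return result.stdout

# Example:
# print(run_user_snippet('print("hello from the sandbox")'))
```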

View File

@ -1,37 +0,0 @@
FROM python:3.11-slim-bookworm
RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g' && \
apt-get update && \
apt-get install -y curl gcc && \
rm -rf /var/lib/apt/lists/*
ARG TARGETARCH
ARG TARGETVARIANT
RUN set -eux; \
case "${TARGETARCH}${TARGETVARIANT}" in \
amd64) DOCKER_ARCH=x86_64 ;; \
arm64) DOCKER_ARCH=aarch64 ;; \
armv7) DOCKER_ARCH=armhf ;; \
armv6) DOCKER_ARCH=armel ;; \
arm64v8) DOCKER_ARCH=aarch64 ;; \
arm64v7) DOCKER_ARCH=armhf ;; \
arm*) DOCKER_ARCH=armhf ;; \
ppc64le) DOCKER_ARCH=ppc64le ;; \
s390x) DOCKER_ARCH=s390x ;; \
*) echo "Unsupported architecture: ${TARGETARCH}${TARGETVARIANT}" && exit 1 ;; \
esac; \
echo "Downloading Docker for architecture: ${DOCKER_ARCH}"; \
curl -fsSL "https://download.docker.com/linux/static/stable/${DOCKER_ARCH}/docker-29.1.0.tgz" | \
tar xz -C /usr/local/bin --strip-components=1 docker/docker; \
ln -sf /usr/local/bin/docker /usr/bin/docker
COPY --from=ghcr.io/astral-sh/uv:0.7.5 /uv /uvx /bin/
ENV UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
WORKDIR /app
COPY . .
RUN uv pip install --system -r requirements.txt
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "9385"]

View File

@ -1,43 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Sandbox providers package.
This package contains:
- base.py: Base interface for all sandbox providers
- manager.py: Provider manager for managing active provider
- self_managed.py: Self-managed provider implementation (wraps existing executor_manager)
- aliyun_codeinterpreter.py: Aliyun Code Interpreter provider implementation
Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
- e2b.py: E2B provider implementation
"""
from .base import SandboxProvider, SandboxInstance, ExecutionResult
from .manager import ProviderManager
from .self_managed import SelfManagedProvider
from .aliyun_codeinterpreter import AliyunCodeInterpreterProvider
from .e2b import E2BProvider
__all__ = [
"SandboxProvider",
"SandboxInstance",
"ExecutionResult",
"ProviderManager",
"SelfManagedProvider",
"AliyunCodeInterpreterProvider",
"E2BProvider",
]

View File

@ -1,512 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Aliyun Code Interpreter provider implementation.
This provider integrates with Aliyun Function Compute Code Interpreter service
for secure code execution in serverless microVMs using the official agentrun-sdk.
Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
Official SDK: https://github.com/Serverless-Devs/agentrun-sdk-python
https://api.aliyun.com/api/AgentRun/2025-09-10/CreateTemplate?lang=PYTHON
https://api.aliyun.com/api/AgentRun/2025-09-10/CreateSandbox?lang=PYTHON
"""
import logging
import os
import time
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
from agentrun.sandbox import TemplateType, CodeLanguage, Template, TemplateInput, Sandbox
from agentrun.utils.config import Config
from agentrun.utils.exception import ServerError
from .base import SandboxProvider, SandboxInstance, ExecutionResult
logger = logging.getLogger(__name__)
class AliyunCodeInterpreterProvider(SandboxProvider):
"""
Aliyun Code Interpreter provider implementation.
This provider uses the official agentrun-sdk to interact with
Aliyun Function Compute's Code Interpreter service.
"""
def __init__(self):
self.access_key_id: Optional[str] = None
self.access_key_secret: Optional[str] = None
self.account_id: Optional[str] = None
self.region: str = "cn-hangzhou"
self.template_name: str = ""
self.timeout: int = 30
self._initialized: bool = False
self._config: Optional[Config] = None
def initialize(self, config: Dict[str, Any]) -> bool:
"""
Initialize the provider with Aliyun credentials.
Args:
config: Configuration dictionary with keys:
- access_key_id: Aliyun AccessKey ID
- access_key_secret: Aliyun AccessKey Secret
- account_id: Aliyun primary account ID (主账号ID)
- region: Region (default: "cn-hangzhou")
- template_name: Optional sandbox template name
- timeout: Request timeout in seconds (default: 30, max 30)
Returns:
True if initialization successful, False otherwise
"""
# Get values from config or environment variables
access_key_id = config.get("access_key_id") or os.getenv("AGENTRUN_ACCESS_KEY_ID")
access_key_secret = config.get("access_key_secret") or os.getenv("AGENTRUN_ACCESS_KEY_SECRET")
account_id = config.get("account_id") or os.getenv("AGENTRUN_ACCOUNT_ID")
region = config.get("region") or os.getenv("AGENTRUN_REGION", "cn-hangzhou")
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.account_id = account_id
self.region = region
self.template_name = config.get("template_name", "")
self.timeout = min(config.get("timeout", 30), 30) # Max 30 seconds
logger.info(f"Aliyun Code Interpreter: Initializing with account_id={self.account_id}, region={self.region}")
# Validate required fields
if not self.access_key_id or not self.access_key_secret:
logger.error("Aliyun Code Interpreter: Missing access_key_id or access_key_secret")
return False
if not self.account_id:
logger.error("Aliyun Code Interpreter: Missing account_id (主账号ID)")
return False
# Create SDK configuration
try:
logger.info(f"Aliyun Code Interpreter: Creating Config object with account_id={self.account_id}")
self._config = Config(
access_key_id=self.access_key_id,
access_key_secret=self.access_key_secret,
account_id=self.account_id,
region_id=self.region,
timeout=self.timeout,
)
logger.info("Aliyun Code Interpreter: Config object created successfully")
# Verify connection with health check
if not self.health_check():
logger.error(f"Aliyun Code Interpreter: Health check failed for region {self.region}")
return False
self._initialized = True
logger.info(f"Aliyun Code Interpreter: Initialized successfully for region {self.region}")
return True
except Exception as e:
logger.error(f"Aliyun Code Interpreter: Initialization failed - {str(e)}")
return False
def create_instance(self, template: str = "python") -> SandboxInstance:
"""
Create a new sandbox instance in Aliyun Code Interpreter.
Args:
template: Programming language (python, javascript)
Returns:
SandboxInstance object
Raises:
RuntimeError: If instance creation fails
"""
if not self._initialized or not self._config:
raise RuntimeError("Provider not initialized. Call initialize() first.")
# Normalize language
language = self._normalize_language(template)
try:
# Get or create template
from agentrun.sandbox import Sandbox
if self.template_name:
# Use existing template
template_name = self.template_name
else:
# Try to get default template, or create one if it doesn't exist
default_template_name = f"ragflow-{language}-default"
try:
# Check if template exists
Template.get_by_name(default_template_name, config=self._config)
template_name = default_template_name
except Exception:
# Create default template if it doesn't exist
template_input = TemplateInput(
template_name=default_template_name,
template_type=TemplateType.CODE_INTERPRETER,
)
Template.create(template_input, config=self._config)
template_name = default_template_name
# Create sandbox directly
sandbox = Sandbox.create(
template_type=TemplateType.CODE_INTERPRETER,
template_name=template_name,
sandbox_idle_timeout_seconds=self.timeout,
config=self._config,
)
instance_id = sandbox.sandbox_id
return SandboxInstance(
instance_id=instance_id,
provider="aliyun_codeinterpreter",
status="READY",
metadata={
"language": language,
"region": self.region,
"account_id": self.account_id,
"template_name": template_name,
"created_at": datetime.now(timezone.utc).isoformat(),
},
)
except ServerError as e:
raise RuntimeError(f"Failed to create sandbox instance: {str(e)}")
except Exception as e:
raise RuntimeError(f"Unexpected error creating instance: {str(e)}")
def execute_code(self, instance_id: str, code: str, language: str, timeout: int = 10, arguments: Optional[Dict[str, Any]] = None) -> ExecutionResult:
"""
Execute code in the Aliyun Code Interpreter instance.
Args:
instance_id: ID of the sandbox instance
code: Source code to execute
language: Programming language (python, javascript)
timeout: Maximum execution time in seconds (max 30)
arguments: Optional arguments dict to pass to main() function
Returns:
ExecutionResult containing stdout, stderr, exit_code, and metadata
Raises:
RuntimeError: If execution fails
TimeoutError: If execution exceeds timeout
"""
if not self._initialized or not self._config:
raise RuntimeError("Provider not initialized. Call initialize() first.")
# Normalize language
normalized_lang = self._normalize_language(language)
# Enforce 30-second hard limit
timeout = min(timeout or self.timeout, 30)
try:
# Connect to existing sandbox instance
sandbox = Sandbox.connect(sandbox_id=instance_id, config=self._config)
# Convert language string to CodeLanguage enum
code_language = CodeLanguage.PYTHON if normalized_lang == "python" else CodeLanguage.JAVASCRIPT
# Wrap code to call main() function
# Matches self_managed provider behavior: call main(**arguments)
if normalized_lang == "python":
# Build arguments string for main() call
if arguments:
import json as json_module
args_json = json_module.dumps(arguments)
wrapped_code = f'''{code}
if __name__ == "__main__":
import json
result = main(**{args_json})
print(json.dumps(result) if isinstance(result, dict) else result)
'''
else:
wrapped_code = f'''{code}
if __name__ == "__main__":
import json
result = main()
print(json.dumps(result) if isinstance(result, dict) else result)
'''
else: # javascript
if arguments:
import json as json_module
args_json = json_module.dumps(arguments)
wrapped_code = f'''{code}
// Call main and output result
const result = main({args_json});
console.log(typeof result === 'object' ? JSON.stringify(result) : String(result));
'''
else:
wrapped_code = f'''{code}
// Call main and output result
const result = main();
console.log(typeof result === 'object' ? JSON.stringify(result) : String(result));
'''
logger.debug(f"Aliyun Code Interpreter: Wrapped code (first 200 chars): {wrapped_code[:200]}")
start_time = time.time()
# Execute code using SDK's simplified execute endpoint
logger.info(f"Aliyun Code Interpreter: Executing code (language={normalized_lang}, timeout={timeout})")
logger.debug(f"Aliyun Code Interpreter: Original code (first 200 chars): {code[:200]}")
result = sandbox.context.execute(
code=wrapped_code,
language=code_language,
timeout=timeout,
)
execution_time = time.time() - start_time
logger.info(f"Aliyun Code Interpreter: Execution completed in {execution_time:.2f}s")
logger.debug(f"Aliyun Code Interpreter: Raw SDK result: {result}")
# Parse execution result
results = result.get("results", []) if isinstance(result, dict) else []
logger.info(f"Aliyun Code Interpreter: Parsed {len(results)} result items")
# Extract stdout and stderr from results
stdout_parts = []
stderr_parts = []
exit_code = 0
execution_status = "ok"
for item in results:
result_type = item.get("type", "")
text = item.get("text", "")
if result_type == "stdout":
stdout_parts.append(text)
elif result_type == "stderr":
stderr_parts.append(text)
exit_code = 1 # Error occurred
elif result_type == "endOfExecution":
execution_status = item.get("status", "ok")
if execution_status != "ok":
exit_code = 1
elif result_type == "error":
stderr_parts.append(text)
exit_code = 1
stdout = "\n".join(stdout_parts)
stderr = "\n".join(stderr_parts)
logger.info(f"Aliyun Code Interpreter: stdout length={len(stdout)}, stderr length={len(stderr)}, exit_code={exit_code}")
if stdout:
logger.debug(f"Aliyun Code Interpreter: stdout (first 200 chars): {stdout[:200]}")
if stderr:
logger.debug(f"Aliyun Code Interpreter: stderr (first 200 chars): {stderr[:200]}")
return ExecutionResult(
stdout=stdout,
stderr=stderr,
exit_code=exit_code,
execution_time=execution_time,
metadata={
"instance_id": instance_id,
"language": normalized_lang,
"context_id": result.get("contextId") if isinstance(result, dict) else None,
"timeout": timeout,
},
)
except ServerError as e:
if "timeout" in str(e).lower():
raise TimeoutError(f"Execution timed out after {timeout} seconds")
raise RuntimeError(f"Failed to execute code: {str(e)}")
except Exception as e:
raise RuntimeError(f"Unexpected error during execution: {str(e)}")
def destroy_instance(self, instance_id: str) -> bool:
"""
Destroy an Aliyun Code Interpreter instance.
Args:
instance_id: ID of the instance to destroy
Returns:
True if destruction successful, False otherwise
"""
if not self._initialized or not self._config:
raise RuntimeError("Provider not initialized. Call initialize() first.")
try:
# Delete sandbox by ID directly
Sandbox.delete_by_id(sandbox_id=instance_id)
logger.info(f"Successfully destroyed sandbox instance {instance_id}")
return True
except ServerError as e:
logger.error(f"Failed to destroy instance {instance_id}: {str(e)}")
return False
except Exception as e:
logger.error(f"Unexpected error destroying instance {instance_id}: {str(e)}")
return False
def health_check(self) -> bool:
"""
Check if the Aliyun Code Interpreter service is accessible.
Returns:
True if provider is healthy, False otherwise
"""
if not self._initialized and not (self.access_key_id and self.account_id):
return False
try:
# Try to list templates to verify connection
from agentrun.sandbox import Template
templates = Template.list(config=self._config)
return templates is not None
except Exception as e:
logger.warning(f"Aliyun Code Interpreter health check failed: {str(e)}")
# If we get any response (even an error), the service is reachable
return "connection" not in str(e).lower()
def get_supported_languages(self) -> List[str]:
"""
Get list of supported programming languages.
Returns:
List of language identifiers
"""
return ["python", "javascript"]
@staticmethod
def get_config_schema() -> Dict[str, Dict]:
"""
Return configuration schema for Aliyun Code Interpreter provider.
Returns:
Dictionary mapping field names to their schema definitions
"""
return {
"access_key_id": {
"type": "string",
"required": True,
"label": "Access Key ID",
"placeholder": "LTAI5t...",
"description": "Aliyun AccessKey ID for authentication",
"secret": False,
},
"access_key_secret": {
"type": "string",
"required": True,
"label": "Access Key Secret",
"placeholder": "••••••••••••••••",
"description": "Aliyun AccessKey Secret for authentication",
"secret": True,
},
"account_id": {
"type": "string",
"required": True,
"label": "Account ID",
"placeholder": "1234567890...",
"description": "Aliyun primary account ID (主账号ID), required for API calls",
},
"region": {
"type": "string",
"required": False,
"label": "Region",
"default": "cn-hangzhou",
"description": "Aliyun region for Code Interpreter service",
"options": ["cn-hangzhou", "cn-beijing", "cn-shanghai", "cn-shenzhen", "cn-guangzhou"],
},
"template_name": {
"type": "string",
"required": False,
"label": "Template Name",
"placeholder": "my-interpreter",
"description": "Optional sandbox template name for pre-configured environments",
},
"timeout": {
"type": "integer",
"required": False,
"label": "Execution Timeout (seconds)",
"default": 30,
"min": 1,
"max": 30,
"description": "Code execution timeout (max 30 seconds - hard limit)",
},
}
def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
"""
Validate Aliyun-specific configuration.
Args:
config: Configuration dictionary to validate
Returns:
Tuple of (is_valid, error_message)
"""
# Validate access key format
access_key_id = config.get("access_key_id", "")
if access_key_id and not access_key_id.startswith("LTAI"):
return False, "Invalid AccessKey ID format (should start with 'LTAI')"
# Validate account ID
account_id = config.get("account_id", "")
if not account_id:
return False, "Account ID is required"
# Validate region
valid_regions = ["cn-hangzhou", "cn-beijing", "cn-shanghai", "cn-shenzhen", "cn-guangzhou"]
region = config.get("region", "cn-hangzhou")
if region and region not in valid_regions:
return False, f"Invalid region. Must be one of: {', '.join(valid_regions)}"
# Validate timeout range (max 30 seconds)
timeout = config.get("timeout", 30)
if isinstance(timeout, int) and (timeout < 1 or timeout > 30):
return False, "Timeout must be between 1 and 30 seconds"
return True, None
def _normalize_language(self, language: str) -> str:
"""
Normalize language identifier to Aliyun format.
Args:
language: Language identifier (python, python3, javascript, nodejs)
Returns:
Normalized language identifier
"""
if not language:
return "python"
lang_lower = language.lower()
if lang_lower in ("python", "python3"):
return "python"
elif lang_lower in ("javascript", "nodejs"):
return "javascript"
else:
return language
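For reference, a hedged sketch of how this (now removed) provider was meant to be driven end to end. The credential values are placeholders, the config keys follow `get_config_schema()` above, and the submitted code defines `main()` because the provider wraps execution around a `main()` call.

```python
# Sketch only: exercises the AliyunCodeInterpreterProvider lifecycle shown above.
# Credentials are placeholders; real values come from system settings.
provider = AliyunCodeInterpreterProvider()
ok = provider.initialize({
    "access_key_id": "LTAI5t...",      # placeholder
    "access_key_secret": "***",        # placeholder
    "account_id": "1234567890...",     # Aliyun primary account ID
    "region": "cn-hangzhou",
    "timeout": 30,
})
if ok:
    instance = provider.create_instance(template="python")
    try:
        result = provider.execute_code(
            instance_id=instance.instance_id,
            code="def main():\n    return {'answer': 42}",
            language="python",
            timeout=10,
        )
        print(result.exit_code, result.stdout)
    finally:
        provider.destroy_instance(instance.instance_id)
```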

View File

@ -1,212 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Base interface for sandbox providers.
Each sandbox provider (self-managed, SaaS) implements this interface
to provide code execution capabilities.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, Optional, List
@dataclass
class SandboxInstance:
"""Represents a sandbox execution instance"""
instance_id: str
provider: str
status: str # running, stopped, error
metadata: Dict[str, Any]
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@dataclass
class ExecutionResult:
"""Result of code execution in a sandbox"""
stdout: str
stderr: str
exit_code: int
execution_time: float # in seconds
metadata: Dict[str, Any]
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
class SandboxProvider(ABC):
"""
Base interface for all sandbox providers.
Each provider implementation (self-managed, Aliyun OpenSandbox, E2B, etc.)
must implement these methods to provide code execution capabilities.
"""
@abstractmethod
def initialize(self, config: Dict[str, Any]) -> bool:
"""
Initialize the provider with configuration.
Args:
config: Provider-specific configuration dictionary
Returns:
True if initialization successful, False otherwise
"""
pass
@abstractmethod
def create_instance(self, template: str = "python") -> SandboxInstance:
"""
Create a new sandbox instance.
Args:
template: Programming language/template for the instance
(e.g., "python", "nodejs", "bash")
Returns:
SandboxInstance object representing the created instance
Raises:
RuntimeError: If instance creation fails
"""
pass
@abstractmethod
def execute_code(
self,
instance_id: str,
code: str,
language: str,
timeout: int = 10,
arguments: Optional[Dict[str, Any]] = None
) -> ExecutionResult:
"""
Execute code in a sandbox instance.
Args:
instance_id: ID of the sandbox instance
code: Source code to execute
language: Programming language (python, javascript, etc.)
timeout: Maximum execution time in seconds
arguments: Optional arguments dict to pass to main() function
Returns:
ExecutionResult containing stdout, stderr, exit_code, and metadata
Raises:
RuntimeError: If execution fails
TimeoutError: If execution exceeds timeout
"""
pass
@abstractmethod
def destroy_instance(self, instance_id: str) -> bool:
"""
Destroy a sandbox instance.
Args:
instance_id: ID of the instance to destroy
Returns:
True if destruction successful, False otherwise
Raises:
RuntimeError: If destruction fails
"""
pass
@abstractmethod
def health_check(self) -> bool:
"""
Check if the provider is healthy and accessible.
Returns:
True if provider is healthy, False otherwise
"""
pass
@abstractmethod
def get_supported_languages(self) -> List[str]:
"""
Get list of supported programming languages.
Returns:
List of language identifiers (e.g., ["python", "javascript", "go"])
"""
pass
@staticmethod
def get_config_schema() -> Dict[str, Dict]:
"""
Return configuration schema for this provider.
The schema defines what configuration fields are required/optional,
their types, validation rules, and UI labels.
Returns:
Dictionary mapping field names to their schema definitions.
Example:
{
"endpoint": {
"type": "string",
"required": True,
"label": "API Endpoint",
"placeholder": "http://localhost:9385"
},
"timeout": {
"type": "integer",
"default": 30,
"label": "Timeout (seconds)",
"min": 5,
"max": 300
}
}
"""
return {}
def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
"""
Validate provider-specific configuration.
This method allows providers to implement custom validation logic beyond
the basic schema validation. Override this method to add provider-specific
checks like URL format validation, API key format validation, etc.
Args:
config: Configuration dictionary to validate
Returns:
Tuple of (is_valid, error_message):
- is_valid: True if configuration is valid, False otherwise
- error_message: Error message if invalid, None if valid
Example:
>>> def validate_config(self, config):
>>> endpoint = config.get("endpoint", "")
>>> if not endpoint.startswith(("http://", "https://")):
>>> return False, "Endpoint must start with http:// or https://"
>>> return True, None
"""
# Default implementation: no custom validation
return True, None
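To make the contract concrete, here is a minimal, hypothetical provider satisfying the abstract interface above. It runs Python in-process via `exec()`, so it provides no isolation whatsoever; it exists only to show the shape of a conforming implementation (the name `InProcessProvider` is illustrative).

```python
# Hypothetical stub provider implementing the SandboxProvider interface above.
# No sandboxing: demo of the method shapes only.
import io
import time
import uuid
from contextlib import redirect_stdout
from typing import Any, Dict, List, Optional

class InProcessProvider(SandboxProvider):
    def initialize(self, config: Dict[str, Any]) -> bool:
        self._ready = True
        return True

    def create_instance(self, template: str = "python") -> SandboxInstance:
        return SandboxInstance(instance_id=str(uuid.uuid4()),
                               provider="in_process", status="running", metadata={})

    def execute_code(self, instance_id: str, code: str, language: str,
                     timeout: int = 10,
                     arguments: Optional[Dict[str, Any]] = None) -> ExecutionResult:
        buf, start = io.StringIO(), time.time()
        try:
            with redirect_stdout(buf):
                exec(code, {})  # captured stdout becomes the result's stdout
            return ExecutionResult(buf.getvalue(), "", 0, time.time() - start, {})
        except Exception as e:
            return ExecutionResult(buf.getvalue(), str(e), 1, time.time() - start, {})

    def destroy_instance(self, instance_id: str) -> bool:
        return True

    def health_check(self) -> bool:
        return getattr(self, "_ready", False)

    def get_supported_languages(self) -> List[str]:
        return ["python"]
```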

View File

@ -1,233 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
E2B provider implementation.
This provider integrates with E2B Cloud for cloud-based code execution
using Firecracker microVMs.
"""
import uuid
from typing import Dict, Any, List
from .base import SandboxProvider, SandboxInstance, ExecutionResult
class E2BProvider(SandboxProvider):
"""
E2B provider implementation.
This provider uses E2B Cloud service for secure code execution
in Firecracker microVMs.
"""
def __init__(self):
self.api_key: str = ""
self.region: str = "us"
self.timeout: int = 30
self._initialized: bool = False
def initialize(self, config: Dict[str, Any]) -> bool:
"""
Initialize the provider with E2B credentials.
Args:
config: Configuration dictionary with keys:
- api_key: E2B API key
- region: Region (us, eu) (default: "us")
- timeout: Request timeout in seconds (default: 30)
Returns:
True if initialization successful, False otherwise
"""
self.api_key = config.get("api_key", "")
self.region = config.get("region", "us")
self.timeout = config.get("timeout", 30)
# Validate required fields
if not self.api_key:
return False
# TODO: Implement actual E2B API client initialization
# For now, we'll mark as initialized but actual API calls will fail
self._initialized = True
return True
def create_instance(self, template: str = "python") -> SandboxInstance:
"""
Create a new sandbox instance in E2B.
Args:
template: Programming language template (python, nodejs, go, bash)
Returns:
SandboxInstance object
Raises:
RuntimeError: If instance creation fails
"""
if not self._initialized:
raise RuntimeError("Provider not initialized. Call initialize() first.")
# Normalize language
language = self._normalize_language(template)
# TODO: Implement actual E2B API call
# POST /sandbox with template
instance_id = str(uuid.uuid4())
return SandboxInstance(
instance_id=instance_id,
provider="e2b",
status="running",
metadata={
"language": language,
"region": self.region,
}
)
def execute_code(
self,
instance_id: str,
code: str,
language: str,
timeout: int = 10
) -> ExecutionResult:
"""
Execute code in the E2B instance.
Args:
instance_id: ID of the sandbox instance
code: Source code to execute
language: Programming language (python, nodejs, go, bash)
timeout: Maximum execution time in seconds
Returns:
ExecutionResult containing stdout, stderr, exit_code, and metadata
Raises:
RuntimeError: If execution fails
TimeoutError: If execution exceeds timeout
"""
if not self._initialized:
raise RuntimeError("Provider not initialized. Call initialize() first.")
# TODO: Implement actual E2B API call
# POST /sandbox/{sandboxID}/execute
raise RuntimeError(
"E2B provider is not yet fully implemented. "
"Please use the self-managed provider or implement the E2B API integration. "
"See https://github.com/e2b-dev/e2b for API documentation."
)
def destroy_instance(self, instance_id: str) -> bool:
"""
Destroy an E2B instance.
Args:
instance_id: ID of the instance to destroy
Returns:
True if destruction successful, False otherwise
"""
if not self._initialized:
raise RuntimeError("Provider not initialized. Call initialize() first.")
# TODO: Implement actual E2B API call
# DELETE /sandbox/{sandboxID}
return True
def health_check(self) -> bool:
"""
Check if the E2B service is accessible.
Returns:
True if provider is healthy, False otherwise
"""
if not self._initialized:
return False
# TODO: Implement actual E2B health check API call
# GET /healthz or similar
# For now, return True if initialized with API key
return bool(self.api_key)
def get_supported_languages(self) -> List[str]:
"""
Get list of supported programming languages.
Returns:
List of language identifiers
"""
return ["python", "nodejs", "javascript", "go", "bash"]
@staticmethod
def get_config_schema() -> Dict[str, Dict]:
"""
Return configuration schema for E2B provider.
Returns:
Dictionary mapping field names to their schema definitions
"""
return {
"api_key": {
"type": "string",
"required": True,
"label": "API Key",
"placeholder": "e2b_sk_...",
"description": "E2B API key for authentication",
"secret": True,
},
"region": {
"type": "string",
"required": False,
"label": "Region",
"default": "us",
"description": "E2B service region (us or eu)",
},
"timeout": {
"type": "integer",
"required": False,
"label": "Request Timeout (seconds)",
"default": 30,
"min": 5,
"max": 300,
"description": "API request timeout for code execution",
}
}
def _normalize_language(self, language: str) -> str:
"""
Normalize language identifier to E2B template format.
Args:
language: Language identifier
Returns:
Normalized language identifier
"""
if not language:
return "python"
lang_lower = language.lower()
if lang_lower in ("python", "python3"):
return "python"
elif lang_lower in ("javascript", "nodejs"):
return "nodejs"
else:
return language

View File

@ -1,78 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Provider manager for sandbox providers.
Since sandbox configuration is global (system-level), we only use one
active provider at a time. This manager is a thin wrapper that holds a reference
to the currently active provider.
"""
from typing import Optional
from .base import SandboxProvider
class ProviderManager:
"""
Manages the currently active sandbox provider.
With global configuration, there's only one active provider at a time.
This manager simply holds a reference to that provider.
"""
def __init__(self):
"""Initialize an empty provider manager."""
self.current_provider: Optional[SandboxProvider] = None
self.current_provider_name: Optional[str] = None
def set_provider(self, name: str, provider: SandboxProvider):
"""
Set the active provider.
Args:
name: Provider identifier (e.g., "self_managed", "e2b")
provider: Provider instance
"""
self.current_provider = provider
self.current_provider_name = name
def get_provider(self) -> Optional[SandboxProvider]:
"""
Get the active provider.
Returns:
Currently active SandboxProvider instance, or None if not set
"""
return self.current_provider
def get_provider_name(self) -> Optional[str]:
"""
Get the active provider name.
Returns:
Provider name (e.g., "self_managed"), or None if not set
"""
return self.current_provider_name
def is_configured(self) -> bool:
"""
Check if a provider is configured.
Returns:
True if a provider is set, False otherwise
"""
return self.current_provider is not None
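A short sketch of how the manager is intended to be wired up; the `SelfManagedProvider` import mirrors the providers package listing above, and the config values are illustrative.

```python
# Sketch: registering one active provider with the ProviderManager shown above.
manager = ProviderManager()
provider = SelfManagedProvider()  # from .self_managed
if provider.initialize({"endpoint": "http://localhost:9385"}):
    manager.set_provider("self_managed", provider)

if manager.is_configured():
    active = manager.get_provider()
    print(manager.get_provider_name(), active.health_check())
```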

View File

@ -1,359 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Self-managed sandbox provider implementation.
This provider wraps the existing executor_manager HTTP API which manages
a pool of Docker containers with gVisor for secure code execution.
"""
import base64
import time
import uuid
from typing import Dict, Any, List, Optional
import requests
from .base import SandboxProvider, SandboxInstance, ExecutionResult
class SelfManagedProvider(SandboxProvider):
"""
Self-managed sandbox provider using Daytona/Docker.
This provider communicates with the executor_manager HTTP API
which manages a pool of containers for code execution.
"""
def __init__(self):
self.endpoint: str = "http://localhost:9385"
self.timeout: int = 30
self.max_retries: int = 3
self.pool_size: int = 10
self._initialized: bool = False
def initialize(self, config: Dict[str, Any]) -> bool:
"""
Initialize the provider with configuration.
Args:
config: Configuration dictionary with keys:
- endpoint: HTTP endpoint (default: "http://localhost:9385")
- timeout: Request timeout in seconds (default: 30)
- max_retries: Maximum retry attempts (default: 3)
- pool_size: Container pool size for info (default: 10)
Returns:
True if initialization successful, False otherwise
"""
self.endpoint = config.get("endpoint", "http://localhost:9385")
self.timeout = config.get("timeout", 30)
self.max_retries = config.get("max_retries", 3)
self.pool_size = config.get("pool_size", 10)
# Validate endpoint is accessible
if not self.health_check():
# Try to fall back to SANDBOX_HOST from settings if we are using localhost
if "localhost" in self.endpoint or "127.0.0.1" in self.endpoint:
try:
from api import settings
if settings.SANDBOX_HOST and settings.SANDBOX_HOST not in self.endpoint:
original_endpoint = self.endpoint
self.endpoint = f"http://{settings.SANDBOX_HOST}:9385"
if self.health_check():
import logging
logging.warning(f"Sandbox self_managed: Connected using settings.SANDBOX_HOST fallback: {self.endpoint} (original: {original_endpoint})")
self._initialized = True
return True
else:
self.endpoint = original_endpoint # Restore if fallback also fails
except ImportError:
pass
return False
self._initialized = True
return True
def create_instance(self, template: str = "python") -> SandboxInstance:
"""
Create a new sandbox instance.
Note: For self-managed provider, instances are managed internally
by the executor_manager's container pool. This method returns
a logical instance handle.
Args:
template: Programming language (python, nodejs)
Returns:
SandboxInstance object
Raises:
RuntimeError: If instance creation fails
"""
if not self._initialized:
raise RuntimeError("Provider not initialized. Call initialize() first.")
# Normalize language
language = self._normalize_language(template)
# The executor_manager manages instances internally via container pool
# We create a logical instance ID for tracking
instance_id = str(uuid.uuid4())
return SandboxInstance(
instance_id=instance_id,
provider="self_managed",
status="running",
metadata={
"language": language,
"endpoint": self.endpoint,
"pool_size": self.pool_size,
}
)
def execute_code(
self,
instance_id: str,
code: str,
language: str,
timeout: int = 10,
arguments: Optional[Dict[str, Any]] = None
) -> ExecutionResult:
"""
Execute code in the sandbox.
Args:
instance_id: ID of the sandbox instance (not used for self-managed)
code: Source code to execute
language: Programming language (python, nodejs, javascript)
timeout: Maximum execution time in seconds
arguments: Optional arguments dict to pass to main() function
Returns:
ExecutionResult containing stdout, stderr, exit_code, and metadata
Raises:
RuntimeError: If execution fails
TimeoutError: If execution exceeds timeout
"""
if not self._initialized:
raise RuntimeError("Provider not initialized. Call initialize() first.")
# Normalize language
normalized_lang = self._normalize_language(language)
# Prepare request
code_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8")
payload = {
"code_b64": code_b64,
"language": normalized_lang,
"arguments": arguments or {}
}
url = f"{self.endpoint}/run"
exec_timeout = timeout or self.timeout
start_time = time.time()
try:
response = requests.post(
url,
json=payload,
timeout=exec_timeout,
headers={"Content-Type": "application/json"}
)
execution_time = time.time() - start_time
if response.status_code != 200:
raise RuntimeError(
f"HTTP {response.status_code}: {response.text}"
)
result = response.json()
return ExecutionResult(
stdout=result.get("stdout", ""),
stderr=result.get("stderr", ""),
exit_code=result.get("exit_code", 0),
execution_time=execution_time,
metadata={
"status": result.get("status"),
"time_used_ms": result.get("time_used_ms"),
"memory_used_kb": result.get("memory_used_kb"),
"detail": result.get("detail"),
"instance_id": instance_id,
}
)
except requests.Timeout:
execution_time = time.time() - start_time
raise TimeoutError(
f"Execution timed out after {exec_timeout} seconds"
)
except requests.RequestException as e:
raise RuntimeError(f"HTTP request failed: {str(e)}")
def destroy_instance(self, instance_id: str) -> bool:
"""
Destroy a sandbox instance.
Note: For self-managed provider, instances are returned to the
internal pool automatically by executor_manager after execution.
This is a no-op for tracking purposes.
Args:
instance_id: ID of the instance to destroy
Returns:
True (always succeeds for self-managed)
"""
# The executor_manager manages container lifecycle internally
# Container is returned to pool after execution
return True
def health_check(self) -> bool:
"""
Check if the provider is healthy and accessible.
Returns:
True if provider is healthy, False otherwise
"""
try:
url = f"{self.endpoint}/healthz"
response = requests.get(url, timeout=5)
return response.status_code == 200
except Exception:
return False
def get_supported_languages(self) -> List[str]:
"""
Get list of supported programming languages.
Returns:
List of language identifiers
"""
return ["python", "nodejs", "javascript"]
@staticmethod
def get_config_schema() -> Dict[str, Dict]:
"""
Return configuration schema for self-managed provider.
Returns:
Dictionary mapping field names to their schema definitions
"""
return {
"endpoint": {
"type": "string",
"required": True,
"label": "Executor Manager Endpoint",
"placeholder": "http://localhost:9385",
"default": "http://localhost:9385",
"description": "HTTP endpoint of the executor_manager service"
},
"timeout": {
"type": "integer",
"required": False,
"label": "Request Timeout (seconds)",
"default": 30,
"min": 5,
"max": 300,
"description": "HTTP request timeout for code execution"
},
"max_retries": {
"type": "integer",
"required": False,
"label": "Max Retries",
"default": 3,
"min": 0,
"max": 10,
"description": "Maximum number of retry attempts for failed requests"
},
"pool_size": {
"type": "integer",
"required": False,
"label": "Container Pool Size",
"default": 10,
"min": 1,
"max": 100,
"description": "Size of the container pool (configured in executor_manager)"
}
}
def _normalize_language(self, language: str) -> str:
"""
Normalize language identifier to executor_manager format.
Args:
language: Language identifier (python, python3, nodejs, javascript)
Returns:
Normalized language identifier
"""
if not language:
return "python"
lang_lower = language.lower()
if lang_lower in ("python", "python3"):
return "python"
elif lang_lower in ("javascript", "nodejs"):
return "nodejs"
else:
return language
def validate_config(self, config: dict) -> tuple[bool, Optional[str]]:
"""
Validate self-managed provider configuration.
Performs custom validation beyond the basic schema validation,
such as checking URL format.
Args:
config: Configuration dictionary to validate
Returns:
Tuple of (is_valid, error_message)
"""
# Validate endpoint URL format
endpoint = config.get("endpoint", "")
if endpoint:
# Check that the endpoint is an HTTP/HTTPS URL (localhost included)
import re
url_pattern = r'^https?://'
if not re.match(url_pattern, endpoint):
return False, f"Invalid endpoint format: {endpoint}. Must start with http:// or https://"
# Validate pool_size is positive
pool_size = config.get("pool_size", 10)
if isinstance(pool_size, int) and pool_size <= 0:
return False, "Pool size must be greater than 0"
# Validate timeout is reasonable
timeout = config.get("timeout", 30)
if isinstance(timeout, int) and (timeout < 1 or timeout > 600):
return False, "Timeout must be between 1 and 600 seconds"
# Validate max_retries
max_retries = config.get("max_retries", 3)
if isinstance(max_retries, int) and (max_retries < 0 or max_retries > 10):
return False, "Max retries must be between 0 and 10"
return True, None
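# Illustrative usage sketch (added for documentation, not part of the original
# module): the flow below mirrors what the unit tests exercise - initialize,
# create an instance, execute code, then release the instance - and assumes an
# executor_manager is reachable at the default endpoint.
if __name__ == "__main__":
    _provider = SelfManagedProvider()
    if _provider.initialize({"endpoint": "http://localhost:9385", "timeout": 30}):
        _instance = _provider.create_instance("python")
        _result = _provider.execute_code(
            instance_id=_instance.instance_id,
            code="def main(): return {'result': 42}",
            language="python",
            timeout=10,
        )
        print(_result.stdout, _result.exit_code)
        _provider.destroy_instance(_instance.instance_id)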

View File

@ -1,261 +0,0 @@
# Aliyun Code Interpreter Provider - Using the Official SDK
## Important Changes
### Official Resources
- **Code Interpreter API**: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
- **Official SDK**: https://github.com/Serverless-Devs/agentrun-sdk-python
- **SDK Documentation**: https://docs.agent.run
## Benefits of Using the Official SDK
Migrating from hand-rolled HTTP requests to the official SDK (`agentrun-sdk`) brings the following benefits:
### 1. **Automatic Request Signing**
- The SDK handles Aliyun API signing automatically (no need to implement the `Authorization` header by hand)
- Supports multiple credential types: AccessKey and STS Token
- Reads credentials from environment variables automatically
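As a quick illustration of the environment-variable flow, here is a minimal sketch that feeds the `AGENTRUN_*` variables used by the integration tests into the provider; the variable names and the config dict shape come from those tests, not from the SDK itself:
```python
import os

from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider

# Same AGENTRUN_* variables the integration tests read; values are left empty if unset.
config = {
    "access_key_id": os.getenv("AGENTRUN_ACCESS_KEY_ID", ""),
    "access_key_secret": os.getenv("AGENTRUN_ACCESS_KEY_SECRET", ""),
    "account_id": os.getenv("AGENTRUN_ACCOUNT_ID", ""),
    "region": os.getenv("AGENTRUN_REGION", "cn-hangzhou"),
}

provider = AliyunCodeInterpreterProvider()
if not provider.initialize(config):
    raise SystemExit("Aliyun Code Interpreter provider failed to initialize")
```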
### 2. **Simpler API**
```python
# Old implementation (manual HTTP request)
response = requests.post(
f"{DATA_ENDPOINT}/sandboxes/{sandbox_id}/execute",
headers={"X-Acs-Parent-Id": account_id},
json={"code": code, "language": "python"}
)
# New implementation (using the SDK)
sandbox = CodeInterpreterSandbox(template_name="python-sandbox", config=config)
result = sandbox.context.execute(code="print('hello')")
```
### 3. **Better Error Handling**
- Structured exception types (`ServerError`)
- Built-in retry mechanism
- Detailed error messages
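A hedged sketch of catching the SDK's structured errors, continuing from the `sandbox` created above; the `ServerError` import and its HTTP-style status-code argument mirror the usage in the unit tests, and the exact attributes it exposes should be checked against the SDK documentation:
```python
from agentrun.utils.exception import ServerError

try:
    result = sandbox.context.execute(code="print('hello')")
except ServerError as exc:
    # The unit tests construct ServerError with an HTTP-style status code,
    # e.g. ServerError(408, "Request timeout"); surface it as a local error.
    raise RuntimeError(f"Code Interpreter request failed: {exc}") from exc
```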
## Main Changes
### 1. File Renames
| Old file name | New file name | Description |
|---------|---------|------|
| `aliyun_opensandbox.py` | `aliyun_codeinterpreter.py` | Provider implementation |
| `test_aliyun_provider.py` | `test_aliyun_codeinterpreter.py` | Unit tests |
| `test_aliyun_integration.py` | `test_aliyun_codeinterpreter_integration.py` | Integration tests |
### 2. Configuration Field Changes
#### Old configuration (OpenSandbox)
```json
{
"access_key_id": "LTAI5t...",
"access_key_secret": "...",
"region": "cn-hangzhou",
"workspace_id": "ws-xxxxx"
}
```
#### New configuration (Code Interpreter)
```json
{
"access_key_id": "LTAI5t...",
"access_key_secret": "...",
"account_id": "1234567890...", // New: Aliyun primary account ID (required)
"region": "cn-hangzhou",
"template_name": "python-sandbox", // New: sandbox template name
"timeout": 30 // Max 30 seconds (hard limit)
}
```
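For a quick sanity check of a config dict in this new shape, the provider's own `validate_config` can be used; the expected error message below mirrors the `test_validate_config_invalid_timeout` unit test, and the credential values are placeholders only:
```python
from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider

provider = AliyunCodeInterpreterProvider()
is_valid, error = provider.validate_config({
    "access_key_id": "LTAI5tXXXXXXXXXX",  # placeholder AccessKey ID
    "account_id": "1234567890123456",     # placeholder primary account ID
    "timeout": 60,                        # exceeds the 30-second hard limit
})
print(is_valid)  # False
print(error)     # "Timeout must be between 1 and 30 seconds ..."
```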
### 3. Key Differences
| Feature | OpenSandbox | Code Interpreter |
|------|-------------|-----------------|
| **API endpoint** | `opensandbox.{region}.aliyuncs.com` | `agentrun.{region}.aliyuncs.com` (control plane) |
| **API version** | `2024-01-01` | `2025-09-10` |
| **Authentication** | AccessKey required | AccessKey + primary account ID required |
| **Request headers** | Standard signature | Requires the `X-Acs-Parent-Id` header |
| **Timeout limit** | Configurable | **Max 30 seconds** (hard limit) |
| **Contexts** | Not supported | Supported (Jupyter kernel) |
### 4. API Call Changes
#### Old implementation (hypothetical OpenSandbox)
```python
# Single endpoint
API_ENDPOINT = "https://opensandbox.cn-hangzhou.aliyuncs.com"
# Simple request/response
response = requests.post(
f"{API_ENDPOINT}/execute",
json={"code": "print('hello')", "language": "python"}
)
```
#### New implementation (Code Interpreter)
```python
# Control-plane API - manages the sandbox lifecycle
CONTROL_ENDPOINT = "https://agentrun.cn-hangzhou.aliyuncs.com/2025-09-10"
# Data-plane API - executes code
DATA_ENDPOINT = "https://{account_id}.agentrun-data.cn-hangzhou.aliyuncs.com"
# Create a sandbox (control plane)
response = requests.post(
f"{CONTROL_ENDPOINT}/sandboxes",
headers={"X-Acs-Parent-Id": account_id},
json={"templateName": "python-sandbox"}
)
# Execute code (data plane)
response = requests.post(
f"{DATA_ENDPOINT}/sandboxes/{sandbox_id}/execute",
headers={"X-Acs-Parent-Id": account_id},
json={"code": "print('hello')", "language": "python", "timeout": 30}
)
```
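Whichever path the request takes, the execution result comes back as a structured payload. A hedged sketch of collecting stdout/stderr from it, following the result shape mocked in the unit tests (`results` entries carrying a `type` field); the real payload may include additional fields:
```python
def collect_output(execute_result: dict):
    """Split an execute() result into (stdout_text, stderr_text)."""
    stdout_parts, stderr_parts = [], []
    for item in execute_result.get("results", []):
        kind = item.get("type")
        if kind == "stdout":
            stdout_parts.append(item.get("text", ""))
        elif kind in ("stderr", "error"):
            stderr_parts.append(item.get("text", ""))
    return "".join(stdout_parts), "".join(stderr_parts)
```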
### 5. Migration Steps
#### Step 1: Update the configuration
If you were previously using `aliyun_opensandbox`:
**Old configuration**:
```json
{
"name": "sandbox.provider_type",
"value": "aliyun_opensandbox"
}
```
**New configuration**:
```json
{
"name": "sandbox.provider_type",
"value": "aliyun_codeinterpreter"
}
```
#### Step 2: Add the required account_id
Get your primary account ID from the avatar menu in the upper-right corner of the Aliyun console:
1. Log in to the [Aliyun console](https://ram.console.aliyun.com/manage/ak)
2. Click the avatar in the upper-right corner
3. Copy the primary account ID (a 16-digit number)
#### Step 3: Update environment variables
```bash
# Newly required environment variable
export ALIYUN_ACCOUNT_ID="1234567890123456"
# The other environment variables stay the same
export ALIYUN_ACCESS_KEY_ID="LTAI5t..."
export ALIYUN_ACCESS_KEY_SECRET="..."
export ALIYUN_REGION="cn-hangzhou"
```
#### Step 4: Run the tests
```bash
# Unit tests (no real credentials required)
pytest agent/sandbox/tests/test_aliyun_codeinterpreter.py -v
# Integration tests (real credentials required)
pytest agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py -v -m integration
```
## File Change List
### ✅ Done
- [x] Created `aliyun_codeinterpreter.py` - new provider implementation
- [x] Updated `sandbox_spec.md` - spec document
- [x] Updated `admin/services.py` - service manager
- [x] Updated `providers/__init__.py` - package exports
- [x] Created `test_aliyun_codeinterpreter.py` - unit tests
- [x] Created `test_aliyun_codeinterpreter_integration.py` - integration tests
### 📝 Optional Cleanup
If you want to remove the old OpenSandbox implementation:
```bash
# Remove the old files (optional)
rm agent/sandbox/providers/aliyun_opensandbox.py
rm agent/sandbox/tests/test_aliyun_provider.py
rm agent/sandbox/tests/test_aliyun_integration.py
```
**Note**: Keeping the old files does not affect the new functionality; it is just redundant code.
## API Reference
### Control-Plane API (Sandbox Management)
| Endpoint | Method | Description |
|------|------|------|
| `/sandboxes` | POST | Create a sandbox instance |
| `/sandboxes/{id}/stop` | POST | Stop an instance |
| `/sandboxes/{id}` | DELETE | Delete an instance |
| `/templates` | GET | List templates |
### Data-Plane API (Code Execution)
| Endpoint | Method | Description |
|------|------|------|
| `/sandboxes/{id}/execute` | POST | Execute code (simplified) |
| `/sandboxes/{id}/contexts` | POST | Create a context |
| `/sandboxes/{id}/contexts/{ctx_id}/execute` | POST | Execute within a context |
| `/sandboxes/{id}/health` | GET | Health check |
| `/sandboxes/{id}/files` | GET/POST | Read/write files |
| `/sandboxes/{id}/processes/cmd` | POST | Run shell commands |
## FAQ
### Q: Why is account_id required?
**A**: The Code Interpreter API requires the `X-Acs-Parent-Id` header (the Aliyun primary account ID) for authentication. It is a mandatory parameter of the Aliyun Code Interpreter API.
### Q: Can the 30-second timeout limit be bypassed?
**A**: No. It is a **hard limit** of Aliyun Code Interpreter and cannot be bypassed through configuration or request parameters. If your code needs more than 30 seconds, consider:
1. Optimizing the code logic
2. Processing the data in smaller batches
3. Using contexts to keep state between executions (see the sketch below)
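A hedged sketch of option 3, splitting long-running work across several executions on the same instance; it mirrors the `test_context_persistence` integration test, and whether the intermediate state actually survives depends on whether the service reuses the same contextId:
```python
# Assumes `provider` is an initialized AliyunCodeInterpreterProvider.
instance = provider.create_instance("python")

# First call: do part of the work and keep the intermediate result in the kernel.
provider.execute_code(
    instance_id=instance.instance_id,
    code="partial = sum(range(1_000_000))",
    language="python",
    timeout=30,
)

# Second call: continue from the stored state instead of recomputing it.
result = provider.execute_code(
    instance_id=instance.instance_id,
    code="print(partial * 2)",
    language="python",
    timeout=30,
)
print(result.stdout)

provider.destroy_instance(instance.instance_id)
```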
### Q: Does the old OpenSandbox configuration still work?
**A**: No. OpenSandbox and Code Interpreter are two different services with incompatible APIs. You must migrate to the new configuration format.
### Q: How do I get my Aliyun primary account ID?
**A**:
1. Log in to the Aliyun console
2. Click the avatar in the upper-right corner
3. The "primary account ID" is shown in the pop-up panel
### Q: Does the migration affect existing functionality?
**A**:
- **Self-managed provider (self_managed)**: not affected
- **E2B provider**: not affected
- **Aliyun provider**: requires updating the configuration and re-testing
## Related Documentation
- [Official documentation](https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter)
- [Sandbox spec](../docs/develop/sandbox_spec.md)
- [Testing guide](./README.md)
- [Quick start](./QUICKSTART.md)
## Support
If you run into problems:
1. Check the official documentation
2. Verify that the configuration is correct
3. Check the error messages in the test output
4. Contact the RAGFlow team

View File

@ -1,178 +0,0 @@
# Aliyun OpenSandbox Provider - Quick Testing Guide
## Test Overview
### 1. Unit Tests (No Real Credentials Required)
The unit tests use mocks, **do not** require real Aliyun credentials, and can be run at any time.
```bash
# Run the Aliyun provider unit tests
pytest agent/sandbox/tests/test_aliyun_provider.py -v
# Expected output:
# test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_provider_initialization PASSED
# test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_initialize_success PASSED
# ...
# ========================= 48 passed in 2.34s ==========================
```
### 2. Integration Tests (Real Credentials Required)
The integration tests call the real Aliyun API and require configured credentials.
#### Step 1: Configure environment variables
```bash
export ALIYUN_ACCESS_KEY_ID="LTAI5t..." # Replace with your real Access Key ID
export ALIYUN_ACCESS_KEY_SECRET="..." # Replace with your real Access Key Secret
export ALIYUN_REGION="cn-hangzhou" # Optional, defaults to cn-hangzhou
```
#### Step 2: Run the integration tests
```bash
# Run all integration tests
pytest agent/sandbox/tests/test_aliyun_integration.py -v -m integration
# Run a specific test
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check -v
```
#### Step 3: Expected output
```
test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_initialize_provider PASSED
test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check PASSED
test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code PASSED
...
========================== 10 passed in 15.67s ==========================
```
### 3. Test Scenarios
#### Basic functionality tests
```bash
# Health check
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check -v
# Create an instance
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_create_python_instance -v
# Execute code
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code -v
# Destroy an instance
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_destroy_instance -v
```
#### Error handling tests
```bash
# Code execution errors
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code_with_error -v
# Timeout handling
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code_timeout -v
```
#### Real-world scenario tests
```bash
# Data processing workflow
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios::test_data_processing_workflow -v
# String manipulation
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios::test_string_manipulation -v
# Multiple executions
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios::test_multiple_executions_same_instance -v
```
## FAQ
### Q: What if I have no credentials?
**A:** Just run the unit tests; they do not need real credentials:
```bash
pytest agent/sandbox/tests/test_aliyun_provider.py -v
```
### Q: How do I skip the integration tests?
**A:** Use the pytest marker to skip them:
```bash
# Run only unit tests, skipping integration tests
pytest agent/sandbox/tests/ -v -m "not integration"
```
### Q: What if the integration tests fail?
**A:** Check the following:
1. **Are the credentials correct?**
```bash
echo $ALIYUN_ACCESS_KEY_ID
echo $ALIYUN_ACCESS_KEY_SECRET
```
2. **Is the network connection working?**
```bash
curl -I https://opensandbox.cn-hangzhou.aliyuncs.com
```
3. **Do you have access to the OpenSandbox service?**
- Log in to the Aliyun console
- Check that the OpenSandbox service is enabled
- Check the AccessKey permissions
4. **Look at the detailed error output**
```bash
pytest agent/sandbox/tests/test_aliyun_integration.py -v -s
```
### Q: What if the tests time out?
**A:** Increase the timeout or check the network:
```bash
# Use a longer timeout
pytest agent/sandbox/tests/test_aliyun_integration.py -v --timeout=60
```
## Test Command Cheat Sheet
| Command | Description | Credentials required |
|------|------|---------|
| `pytest agent/sandbox/tests/test_aliyun_provider.py -v` | Unit tests | ❌ |
| `pytest agent/sandbox/tests/test_aliyun_integration.py -v` | Integration tests | ✅ |
| `pytest agent/sandbox/tests/ -v -m "not integration"` | Unit tests only | ❌ |
| `pytest agent/sandbox/tests/ -v -m integration` | Integration tests only | ✅ |
| `pytest agent/sandbox/tests/ -v` | All tests | Partially |
## Getting Aliyun Credentials
1. Go to the [Aliyun console](https://ram.console.aliyun.com/manage/ak)
2. Create an AccessKey
3. Save the AccessKey ID and AccessKey Secret
4. Set the environment variables
⚠️ **Security notes:**
- Do not hard-code credentials in source code
- Use environment variables or configuration files
- Rotate AccessKeys regularly
- Restrict AccessKey permissions
## Next Steps
1. ✅ **Run the unit tests** - verify the code logic
2. 🔧 **Configure credentials** - set the environment variables
3. 🚀 **Run the integration tests** - exercise the real API
4. 📊 **Review the results** - make sure all tests pass
5. 🎯 **Integrate into the system** - configure the provider via the admin API
## Need Help?
- See the [full documentation](README.md)
- Check the [sandbox spec](../../../../../docs/develop/sandbox_spec.md)
- Contact the RAGFlow team

View File

@ -1,213 +0,0 @@
# Sandbox Provider Tests
This directory contains tests for the RAGFlow sandbox provider system.
## Test Structure
```
tests/
├── pytest.ini # Pytest configuration
├── test_providers.py # Unit tests for all providers (mocked)
├── test_aliyun_provider.py # Unit tests for Aliyun provider (mocked)
├── test_aliyun_integration.py # Integration tests for Aliyun (real API)
└── sandbox_security_tests_full.py # Security tests for self-managed provider
```
## Test Types
### 1. Unit Tests (No Credentials Required)
Unit tests use mocks and don't require any external services or credentials.
**Files:**
- `test_providers.py` - Tests for base provider interface and manager
- `test_aliyun_provider.py` - Tests for Aliyun provider with mocked API calls
**Run unit tests:**
```bash
# Run all unit tests
pytest agent/sandbox/tests/test_providers.py -v
pytest agent/sandbox/tests/test_aliyun_provider.py -v
# Run specific test
pytest agent/sandbox/tests/test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_initialize_success -v
# Run all unit tests (skip integration)
pytest agent/sandbox/tests/ -v -m "not integration"
```
### 2. Integration Tests (Real Credentials Required)
Integration tests make real API calls to the Aliyun OpenSandbox service.
**Files:**
- `test_aliyun_integration.py` - Tests with real Aliyun API calls
**Setup environment variables:**
```bash
export ALIYUN_ACCESS_KEY_ID="LTAI5t..."
export ALIYUN_ACCESS_KEY_SECRET="..."
export ALIYUN_REGION="cn-hangzhou" # Optional, defaults to cn-hangzhou
export ALIYUN_WORKSPACE_ID="ws-..." # Optional
```
**Run integration tests:**
```bash
# Run only integration tests
pytest agent/sandbox/tests/test_aliyun_integration.py -v -m integration
# Run all tests including integration
pytest agent/sandbox/tests/ -v
# Run specific integration test
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check -v
```
### 3. Security Tests
Security tests validate the security features of the self-managed sandbox provider.
**Files:**
- `sandbox_security_tests_full.py` - Comprehensive security tests
**Run security tests:**
```bash
# Run all security tests
pytest agent/sandbox/tests/sandbox_security_tests_full.py -v
# Run specific security test
pytest agent/sandbox/tests/sandbox_security_tests_full.py -k "test_dangerous_imports" -v
```
## Test Commands
### Quick Test Commands
```bash
# Run all sandbox tests (unit only, fast)
pytest agent/sandbox/tests/ -v -m "not integration" --tb=short
# Run tests with coverage
pytest agent/sandbox/tests/ -v --cov=agent.sandbox --cov-report=term-missing -m "not integration"
# Run tests and stop on first failure
pytest agent/sandbox/tests/ -v -x -m "not integration"
# Run tests in parallel (requires pytest-xdist)
pytest agent/sandbox/tests/ -v -n auto -m "not integration"
```
### Aliyun Provider Testing
```bash
# 1. Run unit tests (no credentials needed)
pytest agent/sandbox/tests/test_aliyun_provider.py -v
# 2. Set up credentials for integration tests
export ALIYUN_ACCESS_KEY_ID="your-key-id"
export ALIYUN_ACCESS_KEY_SECRET="your-secret"
export ALIYUN_REGION="cn-hangzhou"
# 3. Run integration tests (makes real API calls)
pytest agent/sandbox/tests/test_aliyun_integration.py -v
# 4. Test specific scenarios
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code -v
pytest agent/sandbox/tests/test_aliyun_integration.py::TestAliyunRealWorldScenarios -v
```
## Understanding Test Results
### Unit Test Output
```
agent/sandbox/tests/test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_initialize_success PASSED
agent/sandbox/tests/test_aliyun_provider.py::TestAliyunOpenSandboxProvider::test_create_instance_python PASSED
...
========================== 48 passed in 2.34s ===========================
```
### Integration Test Output
```
agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check PASSED
agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_create_python_instance PASSED
agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_execute_python_code PASSED
...
========================== 10 passed in 15.67s ===========================
```
**Note:** Integration tests will be skipped if credentials are not set:
```
agent/sandbox/tests/test_aliyun_integration.py::TestAliyunOpenSandboxIntegration::test_health_check SKIPPED
...
========================== 10 skipped, 48 passed in 0.12s ===========================
```
## Troubleshooting
### Integration Tests Fail
1. **Check credentials:**
```bash
echo $ALIYUN_ACCESS_KEY_ID
echo $ALIYUN_ACCESS_KEY_SECRET
```
2. **Check network connectivity:**
```bash
curl -I https://opensandbox.cn-hangzhou.aliyuncs.com
```
3. **Verify permissions:**
- Make sure your Aliyun account has OpenSandbox service enabled
- Check that your AccessKey has the required permissions
4. **Check region:**
- Verify the region is correct for your account
- Try different regions: cn-hangzhou, cn-beijing, cn-shanghai, etc.
### Tests Timeout
If tests timeout, increase the timeout in the test configuration or run with a longer timeout:
```bash
pytest agent/sandbox/tests/test_aliyun_integration.py -v --timeout=60
```
### Mock Tests Fail
If unit tests fail, it's likely a code issue, not a credentials issue:
1. Check the test error message
2. Review the code changes
3. Run with verbose output: `pytest -vv`
## Contributing
When adding new providers:
1. **Create unit tests** in `test_{provider}_provider.py` with mocks
2. **Create integration tests** in `test_{provider}_integration.py` with real API calls
3. **Add markers** to distinguish test types
4. **Update this README** with provider-specific testing instructions
Example:
```python
@pytest.mark.integration
def test_new_provider_real_api():
"""Test with real API calls."""
# Your test here
```
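A matching unit-test sketch that mocks the network layer instead of hitting a real service; it reuses the `@patch('requests.get')` pattern already present in `test_providers.py`:
```python
from unittest.mock import Mock, patch

from agent.sandbox.providers.self_managed import SelfManagedProvider

@patch("requests.get")
def test_health_check_mocked(mock_get):
    """Unit test: the HTTP layer is mocked, so no credentials or services are needed."""
    mock_get.return_value = Mock(status_code=200)
    provider = SelfManagedProvider()
    assert provider.health_check() is True
```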
## Continuous Integration
In CI/CD pipelines:
```yaml
# Run unit tests only (fast, no credentials)
pytest agent/sandbox/tests/ -v -m "not integration"
# Run integration tests if credentials available
if [ -n "$ALIYUN_ACCESS_KEY_ID" ]; then
pytest agent/sandbox/tests/test_aliyun_integration.py -v -m integration
fi
```

View File

@ -1,33 +0,0 @@
[pytest]
# Pytest configuration for sandbox tests
# Test discovery patterns
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Markers for different test types
markers =
integration: Tests that require external services (Aliyun API, etc.)
unit: Fast tests that don't require external services
slow: Tests that take a long time to run
# Test paths
testpaths = .
# Minimum version
minversion = 7.0
# Output options
addopts =
-v
--strict-markers
--tb=short
--disable-warnings
# Log options
log_cli = false
log_cli_level = INFO
# Coverage options (if using pytest-cov)
# addopts = --cov=agent.sandbox --cov-report=html --cov-report=term

View File

@ -1,329 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for Aliyun Code Interpreter provider.
These tests use mocks and don't require real Aliyun credentials.
Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
Official SDK: https://github.com/Serverless-Devs/agentrun-sdk-python
"""
import pytest
from unittest.mock import patch, MagicMock
from agent.sandbox.providers.base import SandboxProvider
from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
class TestAliyunCodeInterpreterProvider:
"""Test AliyunCodeInterpreterProvider implementation."""
def test_provider_initialization(self):
"""Test provider initialization."""
provider = AliyunCodeInterpreterProvider()
assert provider.access_key_id == ""
assert provider.access_key_secret == ""
assert provider.account_id == ""
assert provider.region == "cn-hangzhou"
assert provider.template_name == ""
assert provider.timeout == 30
assert not provider._initialized
@patch("agent.sandbox.providers.aliyun_codeinterpreter.Template")
def test_initialize_success(self, mock_template):
"""Test successful initialization."""
# Mock health check response
mock_template.list.return_value = []
provider = AliyunCodeInterpreterProvider()
result = provider.initialize(
{
"access_key_id": "LTAI5tXXXXXXXXXX",
"access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
"account_id": "1234567890123456",
"region": "cn-hangzhou",
"template_name": "python-sandbox",
"timeout": 20,
}
)
assert result is True
assert provider.access_key_id == "LTAI5tXXXXXXXXXX"
assert provider.access_key_secret == "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
assert provider.account_id == "1234567890123456"
assert provider.region == "cn-hangzhou"
assert provider.template_name == "python-sandbox"
assert provider.timeout == 20
assert provider._initialized
def test_initialize_missing_credentials(self):
"""Test initialization with missing credentials."""
provider = AliyunCodeInterpreterProvider()
# Missing access_key_id
result = provider.initialize({"access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"})
assert result is False
# Missing access_key_secret
result = provider.initialize({"access_key_id": "LTAI5tXXXXXXXXXX"})
assert result is False
# Missing account_id
provider2 = AliyunCodeInterpreterProvider()
result = provider2.initialize({"access_key_id": "LTAI5tXXXXXXXXXX", "access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"})
assert result is False
@patch("agent.sandbox.providers.aliyun_codeinterpreter.Template")
def test_initialize_default_config(self, mock_template):
"""Test initialization with default config."""
mock_template.list.return_value = []
provider = AliyunCodeInterpreterProvider()
result = provider.initialize({"access_key_id": "LTAI5tXXXXXXXXXX", "access_key_secret": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", "account_id": "1234567890123456"})
assert result is True
assert provider.region == "cn-hangzhou"
assert provider.template_name == ""
@patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
def test_create_instance_python(self, mock_sandbox_class):
"""Test creating a Python instance."""
# Mock successful instance creation
mock_sandbox = MagicMock()
mock_sandbox.sandbox_id = "01JCED8Z9Y6XQVK8M2NRST5WXY"
mock_sandbox_class.return_value = mock_sandbox
provider = AliyunCodeInterpreterProvider()
provider._initialized = True
provider._config = MagicMock()
instance = provider.create_instance("python")
assert instance.provider == "aliyun_codeinterpreter"
assert instance.status == "READY"
assert instance.metadata["language"] == "python"
@patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
def test_create_instance_javascript(self, mock_sandbox_class):
"""Test creating a JavaScript instance."""
mock_sandbox = MagicMock()
mock_sandbox.sandbox_id = "01JCED8Z9Y6XQVK8M2NRST5WXY"
mock_sandbox_class.return_value = mock_sandbox
provider = AliyunCodeInterpreterProvider()
provider._initialized = True
provider._config = MagicMock()
instance = provider.create_instance("javascript")
assert instance.metadata["language"] == "javascript"
def test_create_instance_not_initialized(self):
"""Test creating instance when provider not initialized."""
provider = AliyunCodeInterpreterProvider()
with pytest.raises(RuntimeError, match="Provider not initialized"):
provider.create_instance("python")
@patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
def test_execute_code_success(self, mock_sandbox_class):
"""Test successful code execution."""
# Mock sandbox instance
mock_sandbox = MagicMock()
mock_sandbox.context.execute.return_value = {
"results": [{"type": "stdout", "text": "Hello, World!"}, {"type": "result", "text": "None"}, {"type": "endOfExecution", "status": "ok"}],
"contextId": "kernel-12345-67890",
}
mock_sandbox_class.return_value = mock_sandbox
provider = AliyunCodeInterpreterProvider()
provider._initialized = True
provider._config = MagicMock()
result = provider.execute_code(instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", code="print('Hello, World!')", language="python", timeout=10)
assert result.stdout == "Hello, World!"
assert result.stderr == ""
assert result.exit_code == 0
assert result.execution_time > 0
@patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
def test_execute_code_timeout(self, mock_sandbox_class):
"""Test code execution timeout."""
from agentrun.utils.exception import ServerError
mock_sandbox = MagicMock()
mock_sandbox.context.execute.side_effect = ServerError(408, "Request timeout")
mock_sandbox_class.return_value = mock_sandbox
provider = AliyunCodeInterpreterProvider()
provider._initialized = True
provider._config = MagicMock()
with pytest.raises(TimeoutError, match="Execution timed out"):
provider.execute_code(instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", code="while True: pass", language="python", timeout=5)
@patch("agent.sandbox.providers.aliyun_codeinterpreter.CodeInterpreterSandbox")
def test_execute_code_with_error(self, mock_sandbox_class):
"""Test code execution with error."""
mock_sandbox = MagicMock()
mock_sandbox.context.execute.return_value = {
"results": [{"type": "stderr", "text": "Traceback..."}, {"type": "error", "text": "NameError: name 'x' is not defined"}, {"type": "endOfExecution", "status": "error"}]
}
mock_sandbox_class.return_value = mock_sandbox
provider = AliyunCodeInterpreterProvider()
provider._initialized = True
provider._config = MagicMock()
result = provider.execute_code(instance_id="01JCED8Z9Y6XQVK8M2NRST5WXY", code="print(x)", language="python")
assert result.exit_code != 0
assert len(result.stderr) > 0
def test_get_supported_languages(self):
"""Test getting supported languages."""
provider = AliyunCodeInterpreterProvider()
languages = provider.get_supported_languages()
assert "python" in languages
assert "javascript" in languages
def test_get_config_schema(self):
"""Test getting configuration schema."""
schema = AliyunCodeInterpreterProvider.get_config_schema()
assert "access_key_id" in schema
assert schema["access_key_id"]["required"] is True
assert "access_key_secret" in schema
assert schema["access_key_secret"]["required"] is True
assert "account_id" in schema
assert schema["account_id"]["required"] is True
assert "region" in schema
assert "template_name" in schema
assert "timeout" in schema
def test_validate_config_success(self):
"""Test successful configuration validation."""
provider = AliyunCodeInterpreterProvider()
is_valid, error_msg = provider.validate_config({"access_key_id": "LTAI5tXXXXXXXXXX", "account_id": "1234567890123456", "region": "cn-hangzhou"})
assert is_valid is True
assert error_msg is None
def test_validate_config_invalid_access_key(self):
"""Test validation with invalid access key format."""
provider = AliyunCodeInterpreterProvider()
is_valid, error_msg = provider.validate_config({"access_key_id": "INVALID_KEY"})
assert is_valid is False
assert "AccessKey ID format" in error_msg
def test_validate_config_missing_account_id(self):
"""Test validation with missing account ID."""
provider = AliyunCodeInterpreterProvider()
is_valid, error_msg = provider.validate_config({})
assert is_valid is False
assert "Account ID" in error_msg
def test_validate_config_invalid_region(self):
"""Test validation with invalid region."""
provider = AliyunCodeInterpreterProvider()
is_valid, error_msg = provider.validate_config(
{
"access_key_id": "LTAI5tXXXXXXXXXX",
"account_id": "1234567890123456", # Provide required field
"region": "us-west-1",
}
)
assert is_valid is False
assert "Invalid region" in error_msg
def test_validate_config_invalid_timeout(self):
"""Test validation with invalid timeout (> 30 seconds)."""
provider = AliyunCodeInterpreterProvider()
is_valid, error_msg = provider.validate_config(
{
"access_key_id": "LTAI5tXXXXXXXXXX",
"account_id": "1234567890123456", # Provide required field
"timeout": 60,
}
)
assert is_valid is False
assert "Timeout must be between 1 and 30 seconds" in error_msg
def test_normalize_language_python(self):
"""Test normalizing Python language identifier."""
provider = AliyunCodeInterpreterProvider()
assert provider._normalize_language("python") == "python"
assert provider._normalize_language("python3") == "python"
assert provider._normalize_language("PYTHON") == "python"
def test_normalize_language_javascript(self):
"""Test normalizing JavaScript language identifier."""
provider = AliyunCodeInterpreterProvider()
assert provider._normalize_language("javascript") == "javascript"
assert provider._normalize_language("nodejs") == "javascript"
assert provider._normalize_language("JavaScript") == "javascript"
class TestAliyunCodeInterpreterInterface:
"""Test that Aliyun provider correctly implements the interface."""
def test_aliyun_provider_is_abstract(self):
"""Test that AliyunCodeInterpreterProvider is a SandboxProvider."""
provider = AliyunCodeInterpreterProvider()
assert isinstance(provider, SandboxProvider)
def test_aliyun_provider_has_abstract_methods(self):
"""Test that AliyunCodeInterpreterProvider implements all abstract methods."""
provider = AliyunCodeInterpreterProvider()
assert hasattr(provider, "initialize")
assert callable(provider.initialize)
assert hasattr(provider, "create_instance")
assert callable(provider.create_instance)
assert hasattr(provider, "execute_code")
assert callable(provider.execute_code)
assert hasattr(provider, "destroy_instance")
assert callable(provider.destroy_instance)
assert hasattr(provider, "health_check")
assert callable(provider.health_check)
assert hasattr(provider, "get_supported_languages")
assert callable(provider.get_supported_languages)

View File

@ -1,353 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Integration tests for Aliyun Code Interpreter provider.
These tests require real Aliyun credentials and will make actual API calls.
To run these tests, set the following environment variables:
export AGENTRUN_ACCESS_KEY_ID="LTAI5t..."
export AGENTRUN_ACCESS_KEY_SECRET="..."
export AGENTRUN_ACCOUNT_ID="1234567890..." # Aliyun primary account ID (主账号ID)
export AGENTRUN_REGION="cn-hangzhou" # Note: AGENTRUN_REGION (SDK will read this)
Then run:
pytest agent/sandbox/tests/test_aliyun_codeinterpreter_integration.py -v
Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
"""
import os
import pytest
from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
# Skip all tests if credentials are not provided
pytestmark = pytest.mark.skipif(
not all(
[
os.getenv("AGENTRUN_ACCESS_KEY_ID"),
os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
os.getenv("AGENTRUN_ACCOUNT_ID"),
]
),
reason="Aliyun credentials not set. Set AGENTRUN_ACCESS_KEY_ID, AGENTRUN_ACCESS_KEY_SECRET, and AGENTRUN_ACCOUNT_ID.",
)
@pytest.fixture
def aliyun_config():
"""Get Aliyun configuration from environment variables."""
return {
"access_key_id": os.getenv("AGENTRUN_ACCESS_KEY_ID"),
"access_key_secret": os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
"account_id": os.getenv("AGENTRUN_ACCOUNT_ID"),
"region": os.getenv("AGENTRUN_REGION", "cn-hangzhou"),
"template_name": os.getenv("AGENTRUN_TEMPLATE_NAME", ""),
"timeout": 30,
}
@pytest.fixture
def provider(aliyun_config):
"""Create an initialized Aliyun provider."""
provider = AliyunCodeInterpreterProvider()
initialized = provider.initialize(aliyun_config)
if not initialized:
pytest.skip("Failed to initialize Aliyun provider. Check credentials, account ID, and network.")
return provider
@pytest.mark.integration
class TestAliyunCodeInterpreterIntegration:
"""Integration tests for Aliyun Code Interpreter provider."""
def test_initialize_provider(self, aliyun_config):
"""Test provider initialization with real credentials."""
provider = AliyunCodeInterpreterProvider()
result = provider.initialize(aliyun_config)
assert result is True
assert provider._initialized is True
def test_health_check(self, provider):
"""Test health check with real API."""
result = provider.health_check()
assert result is True
def test_get_supported_languages(self, provider):
"""Test getting supported languages."""
languages = provider.get_supported_languages()
assert "python" in languages
assert "javascript" in languages
assert isinstance(languages, list)
def test_create_python_instance(self, provider):
"""Test creating a Python sandbox instance."""
try:
instance = provider.create_instance("python")
assert instance.provider == "aliyun_codeinterpreter"
assert instance.status in ["READY", "CREATING"]
assert instance.metadata["language"] == "python"
assert len(instance.instance_id) > 0
# Clean up
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"Instance creation failed: {str(e)}. API might not be available yet.")
def test_execute_python_code(self, provider):
"""Test executing Python code in the sandbox."""
try:
# Create instance
instance = provider.create_instance("python")
# Execute simple code
result = provider.execute_code(
instance_id=instance.instance_id,
code="print('Hello from Aliyun Code Interpreter!')\nprint(42)",
language="python",
timeout=30, # Max 30 seconds
)
assert result.exit_code == 0
assert "Hello from Aliyun Code Interpreter!" in result.stdout
assert "42" in result.stdout
assert result.execution_time > 0
# Clean up
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"Code execution test failed: {str(e)}. API might not be available yet.")
def test_execute_python_code_with_arguments(self, provider):
"""Test executing Python code with arguments parameter."""
try:
# Create instance
instance = provider.create_instance("python")
# Execute code with arguments
result = provider.execute_code(
instance_id=instance.instance_id,
code="""def main(name: str, count: int) -> dict:
return {"message": f"Hello {name}!" * count}
""",
language="python",
timeout=30,
arguments={"name": "World", "count": 2}
)
assert result.exit_code == 0
assert "Hello World!Hello World!" in result.stdout
# Clean up
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"Arguments test failed: {str(e)}. API might not be available yet.")
def test_execute_python_code_with_error(self, provider):
"""Test executing Python code that produces an error."""
try:
# Create instance
instance = provider.create_instance("python")
# Execute code with error
result = provider.execute_code(instance_id=instance.instance_id, code="raise ValueError('Test error')", language="python", timeout=30)
assert result.exit_code != 0
assert len(result.stderr) > 0 or "ValueError" in result.stdout
# Clean up
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"Error handling test failed: {str(e)}. API might not be available yet.")
def test_execute_javascript_code(self, provider):
"""Test executing JavaScript code in the sandbox."""
try:
# Create instance
instance = provider.create_instance("javascript")
# Execute simple code
result = provider.execute_code(instance_id=instance.instance_id, code="console.log('Hello from JavaScript!');", language="javascript", timeout=30)
assert result.exit_code == 0
assert "Hello from JavaScript!" in result.stdout
# Clean up
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"JavaScript execution test failed: {str(e)}. API might not be available yet.")
def test_execute_javascript_code_with_arguments(self, provider):
"""Test executing JavaScript code with arguments parameter."""
try:
# Create instance
instance = provider.create_instance("javascript")
# Execute code with arguments
result = provider.execute_code(
instance_id=instance.instance_id,
code="""function main(args) {
const { name, count } = args;
return `Hello ${name}!`.repeat(count);
}""",
language="javascript",
timeout=30,
arguments={"name": "World", "count": 2}
)
assert result.exit_code == 0
assert "Hello World!Hello World!" in result.stdout
# Clean up
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"JavaScript arguments test failed: {str(e)}. API might not be available yet.")
def test_destroy_instance(self, provider):
"""Test destroying a sandbox instance."""
try:
# Create instance
instance = provider.create_instance("python")
# Destroy instance
result = provider.destroy_instance(instance.instance_id)
# Note: The API might return True immediately or complete asynchronously
assert isinstance(result, bool)
except Exception as e:
pytest.skip(f"Destroy instance test failed: {str(e)}. API might not be available yet.")
def test_config_validation(self, provider):
"""Test configuration validation."""
# Valid config
is_valid, error = provider.validate_config({"access_key_id": "LTAI5tXXXXXXXXXX", "account_id": "1234567890123456", "region": "cn-hangzhou", "timeout": 30})
assert is_valid is True
assert error is None
# Invalid access key
is_valid, error = provider.validate_config({"access_key_id": "INVALID_KEY"})
assert is_valid is False
# Missing account ID
is_valid, error = provider.validate_config({})
assert is_valid is False
assert "Account ID" in error
def test_timeout_limit(self, provider):
"""Test that timeout is limited to 30 seconds."""
# Timeout > 30 should be clamped to 30
provider2 = AliyunCodeInterpreterProvider()
provider2.initialize(
{
"access_key_id": os.getenv("AGENTRUN_ACCESS_KEY_ID"),
"access_key_secret": os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
"account_id": os.getenv("AGENTRUN_ACCOUNT_ID"),
"timeout": 60, # Request 60 seconds
}
)
# Should be clamped to 30
assert provider2.timeout == 30
@pytest.mark.integration
class TestAliyunCodeInterpreterScenarios:
"""Test real-world usage scenarios."""
def test_data_processing_workflow(self, provider):
"""Test a simple data processing workflow."""
try:
instance = provider.create_instance("python")
# Execute data processing code
code = """
import json
data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
result = json.dumps(data, indent=2)
print(result)
"""
result = provider.execute_code(instance_id=instance.instance_id, code=code, language="python", timeout=30)
assert result.exit_code == 0
assert "Alice" in result.stdout
assert "Bob" in result.stdout
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"Data processing test failed: {str(e)}")
def test_string_manipulation(self, provider):
"""Test string manipulation operations."""
try:
instance = provider.create_instance("python")
code = """
text = "Hello, World!"
print(text.upper())
print(text.lower())
print(text.replace("World", "Aliyun"))
"""
result = provider.execute_code(instance_id=instance.instance_id, code=code, language="python", timeout=30)
assert result.exit_code == 0
assert "HELLO, WORLD!" in result.stdout
assert "hello, world!" in result.stdout
assert "Hello, Aliyun!" in result.stdout
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"String manipulation test failed: {str(e)}")
def test_context_persistence(self, provider):
"""Test code execution with context persistence."""
try:
instance = provider.create_instance("python")
# First execution - define variable
result1 = provider.execute_code(instance_id=instance.instance_id, code="x = 42\nprint(x)", language="python", timeout=30)
assert result1.exit_code == 0
# Second execution - use variable
# Note: Context persistence depends on whether the contextId is reused
result2 = provider.execute_code(instance_id=instance.instance_id, code="print(f'x is {x}')", language="python", timeout=30)
# Context might or might not persist depending on API implementation
assert result2.exit_code == 0
provider.destroy_instance(instance.instance_id)
except Exception as e:
pytest.skip(f"Context persistence test failed: {str(e)}")
def test_without_credentials():
"""Test that tests are skipped without credentials."""
# This test should always run (not skipped)
if all(
[
os.getenv("AGENTRUN_ACCESS_KEY_ID"),
os.getenv("AGENTRUN_ACCESS_KEY_SECRET"),
os.getenv("AGENTRUN_ACCOUNT_ID"),
]
):
assert True # Credentials are set
else:
assert True # Credentials not set, test still passes

View File

@ -1,423 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for sandbox provider abstraction layer.
"""
import pytest
from unittest.mock import Mock, patch
import requests
from agent.sandbox.providers.base import SandboxProvider, SandboxInstance, ExecutionResult
from agent.sandbox.providers.manager import ProviderManager
from agent.sandbox.providers.self_managed import SelfManagedProvider
class TestSandboxDataclasses:
"""Test sandbox dataclasses."""
def test_sandbox_instance_creation(self):
"""Test SandboxInstance dataclass creation."""
instance = SandboxInstance(
instance_id="test-123",
provider="self_managed",
status="running",
metadata={"language": "python"}
)
assert instance.instance_id == "test-123"
assert instance.provider == "self_managed"
assert instance.status == "running"
assert instance.metadata == {"language": "python"}
def test_sandbox_instance_default_metadata(self):
"""Test SandboxInstance with None metadata."""
instance = SandboxInstance(
instance_id="test-123",
provider="self_managed",
status="running",
metadata=None
)
assert instance.metadata == {}
def test_execution_result_creation(self):
"""Test ExecutionResult dataclass creation."""
result = ExecutionResult(
stdout="Hello, World!",
stderr="",
exit_code=0,
execution_time=1.5,
metadata={"status": "success"}
)
assert result.stdout == "Hello, World!"
assert result.stderr == ""
assert result.exit_code == 0
assert result.execution_time == 1.5
assert result.metadata == {"status": "success"}
def test_execution_result_default_metadata(self):
"""Test ExecutionResult with None metadata."""
result = ExecutionResult(
stdout="output",
stderr="error",
exit_code=1,
execution_time=0.5,
metadata=None
)
assert result.metadata == {}
class TestProviderManager:
"""Test ProviderManager functionality."""
def test_manager_initialization(self):
"""Test ProviderManager initialization."""
manager = ProviderManager()
assert manager.current_provider is None
assert manager.current_provider_name is None
assert not manager.is_configured()
def test_set_provider(self):
"""Test setting a provider."""
manager = ProviderManager()
mock_provider = Mock(spec=SandboxProvider)
manager.set_provider("self_managed", mock_provider)
assert manager.current_provider == mock_provider
assert manager.current_provider_name == "self_managed"
assert manager.is_configured()
def test_get_provider(self):
"""Test getting the current provider."""
manager = ProviderManager()
mock_provider = Mock(spec=SandboxProvider)
manager.set_provider("self_managed", mock_provider)
assert manager.get_provider() == mock_provider
def test_get_provider_name(self):
"""Test getting the current provider name."""
manager = ProviderManager()
mock_provider = Mock(spec=SandboxProvider)
manager.set_provider("self_managed", mock_provider)
assert manager.get_provider_name() == "self_managed"
def test_get_provider_when_not_set(self):
"""Test getting provider when none is set."""
manager = ProviderManager()
assert manager.get_provider() is None
assert manager.get_provider_name() is None
class TestSelfManagedProvider:
"""Test SelfManagedProvider implementation."""
def test_provider_initialization(self):
"""Test provider initialization."""
provider = SelfManagedProvider()
assert provider.endpoint == "http://localhost:9385"
assert provider.timeout == 30
assert provider.max_retries == 3
assert provider.pool_size == 10
assert not provider._initialized
@patch('requests.get')
def test_initialize_success(self, mock_get):
"""Test successful initialization."""
mock_response = Mock()
mock_response.status_code = 200
mock_get.return_value = mock_response
provider = SelfManagedProvider()
result = provider.initialize({
"endpoint": "http://test-endpoint:9385",
"timeout": 60,
"max_retries": 5,
"pool_size": 20
})
assert result is True
assert provider.endpoint == "http://test-endpoint:9385"
assert provider.timeout == 60
assert provider.max_retries == 5
assert provider.pool_size == 20
assert provider._initialized
mock_get.assert_called_once_with("http://test-endpoint:9385/healthz", timeout=5)
@patch('requests.get')
def test_initialize_failure(self, mock_get):
"""Test initialization failure."""
mock_get.side_effect = Exception("Connection error")
provider = SelfManagedProvider()
result = provider.initialize({"endpoint": "http://invalid:9385"})
assert result is False
assert not provider._initialized
def test_initialize_default_config(self):
"""Test initialization with default config."""
with patch('requests.get') as mock_get:
mock_response = Mock()
mock_response.status_code = 200
mock_get.return_value = mock_response
provider = SelfManagedProvider()
result = provider.initialize({})
assert result is True
assert provider.endpoint == "http://localhost:9385"
assert provider.timeout == 30
def test_create_instance_python(self):
"""Test creating a Python instance."""
provider = SelfManagedProvider()
provider._initialized = True
instance = provider.create_instance("python")
assert instance.provider == "self_managed"
assert instance.status == "running"
assert instance.metadata["language"] == "python"
assert instance.metadata["endpoint"] == "http://localhost:9385"
assert len(instance.instance_id) > 0 # Verify instance_id exists
def test_create_instance_nodejs(self):
"""Test creating a Node.js instance."""
provider = SelfManagedProvider()
provider._initialized = True
instance = provider.create_instance("nodejs")
assert instance.metadata["language"] == "nodejs"
def test_create_instance_not_initialized(self):
"""Test creating instance when provider not initialized."""
provider = SelfManagedProvider()
with pytest.raises(RuntimeError, match="Provider not initialized"):
provider.create_instance("python")
@patch('requests.post')
def test_execute_code_success(self, mock_post):
"""Test successful code execution."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"status": "success",
"stdout": '{"result": 42}',
"stderr": "",
"exit_code": 0,
"time_used_ms": 100.0,
"memory_used_kb": 1024.0
}
mock_post.return_value = mock_response
provider = SelfManagedProvider()
provider._initialized = True
result = provider.execute_code(
instance_id="test-123",
code="def main(): return {'result': 42}",
language="python",
timeout=10
)
assert result.stdout == '{"result": 42}'
assert result.stderr == ""
assert result.exit_code == 0
assert result.execution_time > 0
assert result.metadata["status"] == "success"
assert result.metadata["instance_id"] == "test-123"
@patch('requests.post')
def test_execute_code_timeout(self, mock_post):
"""Test code execution timeout."""
mock_post.side_effect = requests.Timeout()
provider = SelfManagedProvider()
provider._initialized = True
with pytest.raises(TimeoutError, match="Execution timed out"):
provider.execute_code(
instance_id="test-123",
code="while True: pass",
language="python",
timeout=5
)
@patch('requests.post')
def test_execute_code_http_error(self, mock_post):
"""Test code execution with HTTP error."""
mock_response = Mock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_post.return_value = mock_response
provider = SelfManagedProvider()
provider._initialized = True
with pytest.raises(RuntimeError, match="HTTP 500"):
provider.execute_code(
instance_id="test-123",
code="invalid code",
language="python"
)
def test_execute_code_not_initialized(self):
"""Test executing code when provider not initialized."""
provider = SelfManagedProvider()
with pytest.raises(RuntimeError, match="Provider not initialized"):
provider.execute_code(
instance_id="test-123",
code="print('hello')",
language="python"
)
def test_destroy_instance(self):
"""Test destroying an instance (no-op for self-managed)."""
provider = SelfManagedProvider()
provider._initialized = True
# For self-managed, destroy_instance is a no-op
result = provider.destroy_instance("test-123")
assert result is True
@patch('requests.get')
def test_health_check_success(self, mock_get):
"""Test successful health check."""
mock_response = Mock()
mock_response.status_code = 200
mock_get.return_value = mock_response
provider = SelfManagedProvider()
result = provider.health_check()
assert result is True
mock_get.assert_called_once_with("http://localhost:9385/healthz", timeout=5)
@patch('requests.get')
def test_health_check_failure(self, mock_get):
"""Test health check failure."""
mock_get.side_effect = Exception("Connection error")
provider = SelfManagedProvider()
result = provider.health_check()
assert result is False
def test_get_supported_languages(self):
"""Test getting supported languages."""
provider = SelfManagedProvider()
languages = provider.get_supported_languages()
assert "python" in languages
assert "nodejs" in languages
assert "javascript" in languages
def test_get_config_schema(self):
"""Test getting configuration schema."""
schema = SelfManagedProvider.get_config_schema()
assert "endpoint" in schema
assert schema["endpoint"]["type"] == "string"
assert schema["endpoint"]["required"] is True
assert schema["endpoint"]["default"] == "http://localhost:9385"
assert "timeout" in schema
assert schema["timeout"]["type"] == "integer"
assert schema["timeout"]["default"] == 30
assert "max_retries" in schema
assert schema["max_retries"]["type"] == "integer"
assert "pool_size" in schema
assert schema["pool_size"]["type"] == "integer"
def test_normalize_language_python(self):
"""Test normalizing Python language identifier."""
provider = SelfManagedProvider()
assert provider._normalize_language("python") == "python"
assert provider._normalize_language("python3") == "python"
assert provider._normalize_language("PYTHON") == "python"
assert provider._normalize_language("Python3") == "python"
def test_normalize_language_javascript(self):
"""Test normalizing JavaScript language identifier."""
provider = SelfManagedProvider()
assert provider._normalize_language("javascript") == "nodejs"
assert provider._normalize_language("nodejs") == "nodejs"
assert provider._normalize_language("JavaScript") == "nodejs"
assert provider._normalize_language("NodeJS") == "nodejs"
def test_normalize_language_default(self):
"""Test language normalization with empty/unknown input."""
provider = SelfManagedProvider()
assert provider._normalize_language("") == "python"
assert provider._normalize_language(None) == "python"
assert provider._normalize_language("unknown") == "unknown"
class TestProviderInterface:
"""Test that providers correctly implement the interface."""
def test_self_managed_provider_is_abstract(self):
"""Test that SelfManagedProvider is a SandboxProvider."""
provider = SelfManagedProvider()
assert isinstance(provider, SandboxProvider)
def test_self_managed_provider_has_abstract_methods(self):
"""Test that SelfManagedProvider implements all abstract methods."""
provider = SelfManagedProvider()
# Check all abstract methods are implemented
assert hasattr(provider, 'initialize')
assert callable(provider.initialize)
assert hasattr(provider, 'create_instance')
assert callable(provider.create_instance)
assert hasattr(provider, 'execute_code')
assert callable(provider.execute_code)
assert hasattr(provider, 'destroy_instance')
assert callable(provider.destroy_instance)
assert hasattr(provider, 'health_check')
assert callable(provider.health_check)
assert hasattr(provider, 'get_supported_languages')
assert callable(provider.get_supported_languages)

View File

@ -1,78 +0,0 @@
#!/usr/bin/env python3
"""
Quick verification script for Aliyun Code Interpreter provider using official SDK.
"""
import importlib.util
import sys
sys.path.insert(0, ".")
print("=" * 60)
print("Aliyun Code Interpreter Provider - SDK Verification")
print("=" * 60)
# Test 1: Import provider
print("\n[1/5] Testing provider import...")
try:
from agent.sandbox.providers.aliyun_codeinterpreter import AliyunCodeInterpreterProvider
print("✓ Provider imported successfully")
except ImportError as e:
print(f"✗ Import failed: {e}")
sys.exit(1)
# Test 2: Check provider class
print("\n[2/5] Testing provider class...")
provider = AliyunCodeInterpreterProvider()
assert hasattr(provider, "initialize")
assert hasattr(provider, "create_instance")
assert hasattr(provider, "execute_code")
assert hasattr(provider, "destroy_instance")
assert hasattr(provider, "health_check")
print("✓ Provider has all required methods")
# Test 3: Check SDK imports
print("\n[3/5] Testing SDK imports...")
try:
# Check if agentrun SDK is available using importlib
if (
importlib.util.find_spec("agentrun.sandbox") is None
or importlib.util.find_spec("agentrun.utils.config") is None
or importlib.util.find_spec("agentrun.utils.exception") is None
):
raise ImportError("agentrun SDK not found")
# Verify imports work (assign to _ to indicate they're intentionally unused)
from agentrun.sandbox import CodeInterpreterSandbox, TemplateType, CodeLanguage
from agentrun.utils.config import Config
from agentrun.utils.exception import ServerError
_ = (CodeInterpreterSandbox, TemplateType, CodeLanguage, Config, ServerError)
print("✓ SDK modules imported successfully")
except ImportError as e:
print(f"✗ SDK import failed: {e}")
sys.exit(1)
# Test 4: Check config schema
print("\n[4/5] Testing configuration schema...")
schema = AliyunCodeInterpreterProvider.get_config_schema()
required_fields = ["access_key_id", "access_key_secret", "account_id"]
for field in required_fields:
assert field in schema
assert schema[field]["required"] is True
print(f"✓ All required fields present: {', '.join(required_fields)}")
# Test 5: Check supported languages
print("\n[5/5] Testing supported languages...")
languages = provider.get_supported_languages()
assert "python" in languages
assert "javascript" in languages
print(f"✓ Supported languages: {', '.join(languages)}")
print("\n" + "=" * 60)
print("All verification tests passed! ✓")
print("=" * 60)
print("\nNote: This provider now uses the official agentrun-sdk.")
print("SDK Documentation: https://github.com/Serverless-Devs/agentrun-sdk-python")
print("API Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter")

View File

@ -193,7 +193,7 @@
"presence_penalty": 0.4,
"prompts": [
{
"content": "Text Content:\n{Extractor:NineTiesSin@chunks}\n",
"content": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
"role": "user"
}
],
@ -226,7 +226,7 @@
"presence_penalty": 0.4,
"prompts": [
{
"content": "Text Content:\n\n{Extractor:TastyPointsLay@chunks}\n",
"content": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
"role": "user"
}
],
@ -259,7 +259,7 @@
"presence_penalty": 0.4,
"prompts": [
{
"content": "Content: \n\n{Extractor:BlueResultsWink@chunks}",
"content": "Content: \n\n{Splitter:CuteBusesBet@chunks}",
"role": "user"
}
],
@ -485,7 +485,7 @@
"outputs": {},
"presencePenaltyEnabled": false,
"presence_penalty": 0.4,
"prompts": "Text Content:\n{Extractor:NineTiesSin@chunks}\n",
"prompts": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
"sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nExtract the most important keywords/phrases of a given piece of text content.\n\nRequirements\n- Summarize the text content, and give the top 5 important keywords/phrases.\n- The keywords MUST be in the same language as the given piece of text content.\n- The keywords are delimited by ENGLISH COMMA.\n- Output keywords ONLY.",
"temperature": 0.1,
"temperatureEnabled": false,
@ -522,7 +522,7 @@
"outputs": {},
"presencePenaltyEnabled": false,
"presence_penalty": 0.4,
"prompts": "Text Content:\n\n{Extractor:TastyPointsLay@chunks}\n",
"prompts": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
"sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nPropose 3 questions about a given piece of text content.\n\nRequirements\n- Understand and summarize the text content, and propose the top 3 important questions.\n- The questions SHOULD NOT have overlapping meanings.\n- The questions SHOULD cover the main content of the text as much as possible.\n- The questions MUST be in the same language as the given piece of text content.\n- One question per line.\n- Output questions ONLY.",
"temperature": 0.1,
"temperatureEnabled": false,
@ -559,7 +559,7 @@
"outputs": {},
"presencePenaltyEnabled": false,
"presence_penalty": 0.4,
"prompts": "Content: \n\n{Extractor:BlueResultsWink@chunks}",
"prompts": "Content: \n\n{Splitter:BlueResultsWink@chunks}",
"sys_prompt": "Extract important structured information from the given content. Output ONLY a valid JSON string with no additional text. If no important structured information is found, output an empty JSON object: {}.\n\nImportant structured information may include: names, dates, locations, events, key facts, numerical data, or other extractable entities.",
"temperature": 0.1,
"temperatureEnabled": false,

View File

@ -5,9 +5,9 @@
"de": "Wählen Sie Ihren Wissensdatenbank Agenten",
"zh": "选择知识库智能体"},
"description": {
"en": "This Agent generates responses solely from the specified dataset (knowledge base). You are required to select a knowledge base from the dropdown when running the Agent.",
"de": "Dieser Agent erzeugt Antworten ausschließlich aus dem angegebenen Datensatz (Wissensdatenbank). Beim Ausführen des Agents müssen Sie eine Wissensdatenbank aus dem Dropdown-Menü auswählen.",
"zh": "本工作流仅根据指定知识库内容生成回答。运行时,请在下拉菜单选择需要查询的知识库。"},
"en": "Select your desired knowledge base from the dropdown menu. The Agent will only retrieve from the selected knowledge base and use this content to generate responses.",
"de": "Wählen Sie Ihre gewünschte Wissensdatenbank aus dem Dropdown-Menü. Der Agent ruft nur Informationen aus der ausgewählten Wissensdatenbank ab und verwendet diesen Inhalt zur Generierung von Antworten.",
"zh": "从下拉菜单中选择知识库,智能体将仅根据所选知识库内容生成回答。"},
"canvas_type": "Agent",
"dsl": {
"components": {
@ -387,10 +387,10 @@
{
"data": {
"form": {
"text": "This Agent generates responses solely from the specified dataset (knowledge base). \nYou are required to select a knowledge base from the dropdown when running the Agent."
"text": "Select your desired knowledge base from the dropdown menu. \nThe Agent will only retrieve from the selected knowledge base and use this content to generate responses."
},
"label": "Note",
"name": "Workflow description"
"name": "Workflow overall description"
},
"dragHandle": ".note-drag-handle",
"dragging": false,

View File

@ -5,9 +5,9 @@
"de": "Wählen Sie Ihren Wissensdatenbank Workflow",
"zh": "选择知识库工作流"},
"description": {
"en": "This Agent generates responses solely from the specified dataset (knowledge base). You are required to select a knowledge base from the dropdown when running the Agent.",
"de": "Dieser Agent erzeugt Antworten ausschließlich aus dem angegebenen Datensatz (Wissensdatenbank). Beim Ausführen des Agents müssen Sie eine Wissensdatenbank aus dem Dropdown-Menü auswählen.",
"zh": "工作流仅根据指定知识库内容生成回答。运行时,请在下拉菜单选择需要查询的知识库。"},
"en": "Select your desired knowledge base from the dropdown menu. The retrieval assistant will only use data from your selected knowledge base to generate responses.",
"de": "Wählen Sie Ihre gewünschte Wissensdatenbank aus dem Dropdown-Menü. Der Abrufassistent verwendet nur Daten aus Ihrer ausgewählten Wissensdatenbank, um Antworten zu generieren.",
"zh": "从下拉菜单中选择知识库,工作流仅根据所选知识库内容生成回答。"},
"canvas_type": "Other",
"dsl": {
"components": {
@ -334,10 +334,10 @@
{
"data": {
"form": {
"text": "This Agent generates responses solely from the specified dataset (knowledge base). \nYou are required to select a knowledge base from the dropdown when running the Agent."
"text": "Select your desired knowledge base from the dropdown menu. \nThe retrieval assistant will only use data from your selected knowledge base to generate responses."
},
"label": "Note",
"name": "Workflow description"
"name": "Workflow overall description"
},
"dragHandle": ".note-drag-handle",
"dragging": false,

View File

@ -2,12 +2,10 @@
"id": 27,
"title": {
"en": "Interactive Agent",
"de": "Interaktiver Agent",
"zh": "可交互的 Agent"
},
"description": {
"en": "During the Agents execution, users can actively intervene and interact with the Agent to adjust or guide its output, ensuring the final result aligns with their intentions.",
"de": "Wahrend der Ausführung des Agenten können Benutzer aktiv eingreifen und mit dem Agenten interagieren, um dessen Ausgabe zu steuern, sodass das Endergebnis ihren Vorstellungen entspricht.",
"zh": "在 Agent 的运行过程中,用户可以随时介入,与 Agent 进行交互,以调整或引导生成结果,使最终输出更符合预期。"
},
"canvas_type": "Agent",

View File

@ -27,10 +27,6 @@ from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer
from common.misc_utils import thread_pool_exec
class ToolParameter(TypedDict):
type: str
description: str
@ -60,12 +56,12 @@ class LLMToolPluginCallSession(ToolCallSession):
st = timer()
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
else:
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = await tool_obj.invoke_async(**arguments)
else:
resp = await thread_pool_exec(tool_obj.invoke, **arguments)
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp
@ -126,7 +122,6 @@ class ToolParamBase(ComponentParamBase):
class ToolBase(ComponentBase):
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
@ -169,7 +164,7 @@ class ToolBase(ComponentBase):
elif asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
res = await thread_pool_exec(self._invoke, **kwargs)
res = await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
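Several hunks in this file (and in later files of this set) swap between the project's `thread_pool_exec` helper and the standard-library `asyncio.to_thread`. Both run a blocking callable off the event loop and await its result; a minimal self-contained illustration of the `asyncio.to_thread` side, independent of the project code:

```python
import asyncio
import time

def blocking_invoke(x: int) -> int:
    time.sleep(0.1)  # stands in for a blocking tool call such as tool_obj.invoke(**kwargs)
    return x * 2

async def main() -> None:
    # asyncio.to_thread(func, *args, **kwargs) runs func in the default thread pool
    # and suspends the coroutine until the result is ready, without blocking the loop.
    result = await asyncio.to_thread(blocking_invoke, 21)
    print(result)  # 42

asyncio.run(main())
```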

View File

@ -110,7 +110,7 @@ module.exports = { main };
self.lang = Language.PYTHON.value
self.script = 'def main(arg1: str, arg2: str) -> dict: return {"result": arg1 + arg2}'
self.arguments = {}
self.outputs = {"result": {"value": "", "type": "object"}}
self.outputs = {"result": {"value": "", "type": "string"}}
def check(self):
self.check_valid_value(self.lang, "Support languages", ["python", "python3", "nodejs", "javascript"])
@ -140,61 +140,26 @@ class CodeExec(ToolBase, ABC):
continue
arguments[k] = self._canvas.get_variable_value(v) if v else None
return self._execute_code(language=lang, code=script, arguments=arguments)
self._execute_code(language=lang, code=script, arguments=arguments)
def _execute_code(self, language: str, code: str, arguments: dict):
import requests
if self.check_if_canceled("CodeExec execution"):
return self.output()
return
try:
# Try using the new sandbox provider system first
try:
from agent.sandbox.client import execute_code as sandbox_execute_code
if self.check_if_canceled("CodeExec execution"):
return
# Execute code using the provider system
result = sandbox_execute_code(
code=code,
language=language,
timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)),
arguments=arguments
)
if self.check_if_canceled("CodeExec execution"):
return
# Process the result
if result.stderr:
self.set_output("_ERROR", result.stderr)
return
parsed_stdout = self._deserialize_stdout(result.stdout)
logging.info(f"[CodeExec]: Provider system -> {parsed_stdout}")
self._populate_outputs(parsed_stdout, result.stdout)
return
except (ImportError, RuntimeError) as provider_error:
# Provider system not available or not configured, fall back to HTTP
logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}")
# Fallback to direct HTTP request
code_b64 = self._encode_code(code)
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
except Exception as e:
if self.check_if_canceled("CodeExec execution"):
return self.output()
return
self.set_output("_ERROR", "construct code request error: " + str(e))
return self.output()
try:
if self.check_if_canceled("CodeExec execution"):
self.set_output("_ERROR", "Task has been canceled")
return self.output()
return "Task has been canceled"
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
@ -209,18 +174,17 @@ class CodeExec(ToolBase, ABC):
stderr = body.get("stderr")
if stderr:
self.set_output("_ERROR", stderr)
return self.output()
return
raw_stdout = body.get("stdout", "")
parsed_stdout = self._deserialize_stdout(raw_stdout)
logging.info(f"[CodeExec]: http://{settings.SANDBOX_HOST}:9385/run -> {parsed_stdout}")
self._populate_outputs(parsed_stdout, raw_stdout)
else:
self.set_output("_ERROR", "There is no response from sandbox")
return self.output()
except Exception as e:
if self.check_if_canceled("CodeExec execution"):
return self.output()
return
self.set_output("_ERROR", "Exception executing code: " + str(e))
@ -331,8 +295,6 @@ class CodeExec(ToolBase, ABC):
if key.startswith("_"):
continue
val = self._get_by_path(parsed_stdout, key)
if val is None and len(outputs_items) == 1:
val = parsed_stdout
coerced = self._coerce_output_value(val, meta.get("type"))
logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'")
self.set_output(key, coerced)
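The larger hunk above toggles between two execution paths for `CodeExec`: a sandbox provider system and a direct HTTP request to the sandbox service. Below is a condensed sketch of that control flow, not the project's exact implementation: `SANDBOX_HOST` is read from an environment variable here (the real code uses `settings.SANDBOX_HOST`), and base64 encoding for `code_b64` is assumed from the field name.

```python
import base64
import logging
import os

import requests

def execute_with_fallback(code: str, language: str, arguments: dict) -> dict:
    """Prefer the sandbox provider system; fall back to the HTTP sandbox endpoint."""
    timeout = int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
    try:
        # Optional dependency: present only when the provider system is installed and configured.
        from agent.sandbox.client import execute_code as sandbox_execute_code
        result = sandbox_execute_code(code=code, language=language,
                                      timeout=timeout, arguments=arguments)
        return {"stdout": result.stdout, "stderr": result.stderr}
    except (ImportError, RuntimeError) as provider_error:
        logging.info("Provider system not available, using HTTP fallback: %s", provider_error)
    code_req = {"code_b64": base64.b64encode(code.encode("utf-8")).decode(),
                "language": language, "arguments": arguments}
    sandbox_host = os.environ.get("SANDBOX_HOST", "localhost")  # assumed default
    resp = requests.post(f"http://{sandbox_host}:9385/run", json=code_req, timeout=timeout)
    return resp.json()
```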

View File

@ -53,7 +53,7 @@ class ExeSQLParam(ToolParamBase):
self.max_records = 1024
def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino', 'oceanbase'])
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
@ -86,12 +86,6 @@ class ExeSQL(ToolBase, ABC):
def convert_decimals(obj):
from decimal import Decimal
import math
if isinstance(obj, float):
# Handle NaN and Infinity which are not valid JSON values
if math.isnan(obj) or math.isinf(obj):
return None
return obj
if isinstance(obj, Decimal):
return float(obj)  # or str(obj)
elif isinstance(obj, dict):
@ -126,9 +120,6 @@ class ExeSQL(ToolBase, ABC):
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'oceanbase':
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password, charset='utf8mb4')
elif self._param.db_type == 'postgres':
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
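For context on the `convert_decimals` hunk: `Decimal` values returned by database drivers are not JSON-serializable, and `json.dumps` emits bare `NaN`/`Infinity` for non-finite floats, which is not valid JSON for downstream consumers. A standalone sketch of such a converter (not necessarily the project's final version):

```python
import json
import math
from decimal import Decimal

def convert_decimals(obj):
    """Recursively make DB query results JSON-safe: Decimal -> float, NaN/Inf -> None."""
    if isinstance(obj, float):
        return None if math.isnan(obj) or math.isinf(obj) else obj
    if isinstance(obj, Decimal):
        return float(obj)
    if isinstance(obj, dict):
        return {k: convert_decimals(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_decimals(v) for v in obj]
    return obj

print(json.dumps(convert_decimals({"price": Decimal("1.50"), "ratio": float("nan")})))
# {"price": 1.5, "ratio": null}
```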

View File

@ -21,12 +21,12 @@ import re
from abc import ABC
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from common.constants import LLMType
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.document_service import DocumentService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.memory_service import MemoryService
from api.db.joint_services import memory_message_service
from api.db.joint_services.memory_message_service import query_message
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
@ -125,7 +125,7 @@ class Retrieval(ToolBase, ABC):
doc_ids = []
if self._param.meta_data_filter != {}:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
metas = DocumentService.get_meta_by_kbs(kb_ids)
def _resolve_manual_filter(flt: dict) -> dict:
pat = re.compile(self.variable_ref_patt)
@ -174,7 +174,7 @@ class Retrieval(ToolBase, ABC):
if kbs:
query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
kbinfos = await settings.retriever.retrieval(
kbinfos = settings.retriever.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
@ -193,7 +193,7 @@ class Retrieval(ToolBase, ABC):
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
@ -202,7 +202,7 @@ class Retrieval(ToolBase, ABC):
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
[kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = await settings.kg_retriever.retrieval(query,
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
@ -215,7 +215,7 @@ class Retrieval(ToolBase, ABC):
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
@ -259,36 +259,36 @@ class Retrieval(ToolBase, ABC):
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
# query message
message_list = memory_message_service.query_message({"memory_id": memory_ids}, {
message_list = query_message({"memory_id": memory_ids}, {
"query": query,
"similarity_threshold": self._param.similarity_threshold,
"keywords_similarity_weight": self._param.keywords_similarity_weight,
"top_n": self._param.top_n
})
print(f"found {len(message_list)} messages.")
if not message_list:
self.set_output("formalized_content", self._param.empty_response)
return ""
return
formated_content = "\n".join(memory_prompt(message_list, 200000))
# set formalized_content output
self.set_output("formalized_content", formated_content)
print(f"formated_content {formated_content}")
return formated_content
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
print(f"debug retrieval, query is {kwargs.get('query')}.", flush=True)
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
if hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "dataset":
if self._param.kb_ids:
return await self._retrieve_kb(kwargs["query"])
elif hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "memory":
return await self._retrieve_memory(kwargs["query"])
elif self._param.kb_ids:
return await self._retrieve_kb(kwargs["query"])
elif hasattr(self._param, "memory_ids") and self._param.memory_ids:
elif self._param.memory_ids:
return await self._retrieve_memory(kwargs["query"])
else:
self.set_output("formalized_content", self._param.empty_response)

View File

@ -0,0 +1 @@
from .deep_research import DeepResearcher as DeepResearcher

View File

@ -0,0 +1,238 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from functools import partial
from agentic_reasoning.prompts import BEGIN_SEARCH_QUERY, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT, MAX_SEARCH_LIMIT, \
END_SEARCH_QUERY, REASON_PROMPT, RELEVANT_EXTRACTION_PROMPT
from api.db.services.llm_service import LLMBundle
from rag.nlp import extract_between
from rag.prompts import kb_prompt
from rag.utils.tavily_conn import Tavily
class DeepResearcher:
def __init__(self,
chat_mdl: LLMBundle,
prompt_config: dict,
kb_retrieve: partial = None,
kg_retrieve: partial = None
):
self.chat_mdl = chat_mdl
self.prompt_config = prompt_config
self._kb_retrieve = kb_retrieve
self._kg_retrieve = kg_retrieve
def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
"""General Tag Removal Method"""
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
return re.sub(pattern, "", text)
@staticmethod
def _remove_query_tags(text: str) -> str:
"""Remove Query Tags"""
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
@staticmethod
def _remove_result_tags(text: str) -> str:
"""Remove Result Tags"""
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
async def _generate_reasoning(self, msg_history):
"""Generate reasoning steps"""
query_think = ""
if msg_history[-1]["role"] != "user":
msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
else:
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if not ans:
continue
query_think = ans
yield query_think
query_think = ""
yield query_think
def _extract_search_queries(self, query_think, question, step_index):
"""Extract search queries from thinking"""
queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
if not queries and step_index == 0:
# If this is the first step and no queries are found, use the original question as the query
queries = [question]
return queries
def _truncate_previous_reasoning(self, all_reasoning_steps):
"""Truncate previous reasoning steps to maintain a reasonable length"""
truncated_prev_reasoning = ""
for i, step in enumerate(all_reasoning_steps):
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
prev_steps = truncated_prev_reasoning.split('\n\n')
if len(prev_steps) <= 5:
truncated_prev_reasoning = '\n\n'.join(prev_steps)
else:
truncated_prev_reasoning = ''
for i, step in enumerate(prev_steps):
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
truncated_prev_reasoning += step + '\n\n'
else:
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
truncated_prev_reasoning += '...\n\n'
return truncated_prev_reasoning.strip('\n')
def _retrieve_information(self, search_query):
"""Retrieve information from different sources"""
# 1. Knowledge base retrieval
kbinfos = []
try:
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
except Exception as e:
logging.error(f"Knowledge base retrieval error: {e}")
# 2. Web retrieval (if Tavily API is configured)
try:
if self.prompt_config.get("tavily_api_key"):
tav = Tavily(self.prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(search_query)
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
except Exception as e:
logging.error(f"Web retrieval error: {e}")
# 3. Knowledge graph retrieval (if configured)
try:
if self.prompt_config.get("use_kg") and self._kg_retrieve:
ck = self._kg_retrieve(question=search_query)
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
except Exception as e:
logging.error(f"Knowledge graph retrieval error: {e}")
return kbinfos
def _update_chunk_info(self, chunk_info, kbinfos):
"""Update chunk information for citations"""
if not chunk_info["chunks"]:
# If this is the first retrieval, use the retrieval results directly
for k in chunk_info.keys():
chunk_info[k] = kbinfos[k]
else:
# Merge newly retrieved information, avoiding duplicates
cids = [c["chunk_id"] for c in chunk_info["chunks"]]
for c in kbinfos["chunks"]:
if c["chunk_id"] not in cids:
chunk_info["chunks"].append(c)
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
for d in kbinfos["doc_aggs"]:
if d["doc_id"] not in dids:
chunk_info["doc_aggs"].append(d)
async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
"""Extract and summarize relevant information"""
summary_think = ""
async for ans in self.chat_mdl.async_chat_streamly(
RELEVANT_EXTRACTION_PROMPT.format(
prev_reasoning=truncated_prev_reasoning,
search_query=search_query,
document="\n".join(kb_prompt(kbinfos, 4096))
),
[{"role": "user",
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
{"temperature": 0.7}):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if not ans:
continue
summary_think = ans
yield summary_think
summary_think = ""
yield summary_think
async def thinking(self, chunk_info: dict, question: str):
executed_search_queries = []
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
all_reasoning_steps = []
think = "<think>"
for step_index in range(MAX_SEARCH_LIMIT + 1):
# Check if the maximum search limit has been reached
if step_index == MAX_SEARCH_LIMIT - 1:
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
all_reasoning_steps.append(summary_think)
msg_history.append({"role": "assistant", "content": summary_think})
break
# Step 1: Generate reasoning
query_think = ""
async for ans in self._generate_reasoning(msg_history):
query_think = ans
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
think += self._remove_query_tags(query_think)
all_reasoning_steps.append(query_think)
# Step 2: Extract search queries
queries = self._extract_search_queries(query_think, question, step_index)
if not queries and step_index > 0:
# If not the first step and no queries, end the search process
break
# Process each search query
for search_query in queries:
logging.info(f"[THINK]Query: {step_index}. {search_query}")
msg_history.append({"role": "assistant", "content": search_query})
think += f"\n\n> {step_index + 1}. {search_query}\n\n"
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
# Check if the query has already been executed
if search_query in executed_search_queries:
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
all_reasoning_steps.append(summary_think)
msg_history.append({"role": "user", "content": summary_think})
think += summary_think
continue
executed_search_queries.append(search_query)
# Step 3: Truncate previous reasoning steps
truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
# Step 4: Retrieve information
kbinfos = self._retrieve_information(search_query)
# Step 5: Update chunk information
self._update_chunk_info(chunk_info, kbinfos)
# Step 6: Extract relevant information
think += "\n\n"
summary_think = ""
async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
summary_think = ans
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
all_reasoning_steps.append(summary_think)
msg_history.append(
{"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
think += self._remove_result_tags(summary_think)
logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
yield think + "</think>"

View File

@ -0,0 +1,147 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
END_SEARCH_QUERY = "<|end_search_query|>"
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
END_SEARCH_RESULT = "<|end_search_result|>"
MAX_SEARCH_LIMIT = 6
REASON_PROMPT = f"""You are an advanced reasoning agent. Your goal is to answer the user's question by breaking it down into a series of verifiable steps.
You have access to a powerful search tool to find information.
**Your Task:**
1. Analyze the user's question.
2. If you need information, issue a search query to find a specific fact.
3. Review the search results.
4. Repeat the search process until you have all the facts needed to answer the question.
5. Once you have gathered sufficient information, synthesize the facts and provide the final answer directly.
**Tool Usage:**
- To search, you MUST write your query between the special tokens: {BEGIN_SEARCH_QUERY}your query{END_SEARCH_QUERY}.
- The system will provide results between {BEGIN_SEARCH_RESULT}search results{END_SEARCH_RESULT}.
- You have a maximum of {MAX_SEARCH_LIMIT} search attempts.
---
**Example 1: Multi-hop Question**
**Question:** "Are both the directors of Jaws and Casino Royale from the same country?"
**Your Thought Process & Actions:**
First, I need to identify the director of Jaws.
{BEGIN_SEARCH_QUERY}who is the director of Jaws?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Jaws is a 1975 American thriller film directed by Steven Spielberg.
{END_SEARCH_RESULT}
Okay, the director of Jaws is Steven Spielberg. Now I need to find out his nationality.
{BEGIN_SEARCH_QUERY}where is Steven Spielberg from?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Steven Allan Spielberg is an American filmmaker. Born in Cincinnati, Ohio...
{END_SEARCH_RESULT}
So, Steven Spielberg is from the USA. Next, I need to find the director of Casino Royale.
{BEGIN_SEARCH_QUERY}who is the director of Casino Royale 2006?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Casino Royale is a 2006 spy film directed by Martin Campbell.
{END_SEARCH_RESULT}
The director of Casino Royale is Martin Campbell. Now I need his nationality.
{BEGIN_SEARCH_QUERY}where is Martin Campbell from?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Martin Campbell (born 24 October 1943) is a New Zealand film and television director.
{END_SEARCH_RESULT}
I have all the information. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand. They are not from the same country.
Final Answer: No, the directors of Jaws and Casino Royale are not from the same country. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand.
---
**Example 2: Simple Fact Retrieval**
**Question:** "When was the founder of craigslist born?"
**Your Thought Process & Actions:**
First, I need to know who founded craigslist.
{BEGIN_SEARCH_QUERY}who founded craigslist?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Craigslist was founded in 1995 by Craig Newmark.
{END_SEARCH_RESULT}
The founder is Craig Newmark. Now I need his birth date.
{BEGIN_SEARCH_QUERY}when was Craig Newmark born?{END_SEARCH_QUERY}
[System returns search results]
{BEGIN_SEARCH_RESULT}
Craig Newmark was born on December 6, 1952.
{END_SEARCH_RESULT}
I have found the answer.
Final Answer: The founder of craigslist, Craig Newmark, was born on December 6, 1952.
---
**Important Rules:**
- **One Fact at a Time:** Decompose the problem and issue one search query at a time to find a single, specific piece of information.
- **Be Precise:** Formulate clear and precise search queries. If a search fails, rephrase it.
- **Synthesize at the End:** Do not provide the final answer until you have completed all necessary searches.
- **Language Consistency:** Your search queries should be in the same language as the user's question.
Now, begin your work. Please answer the following question by thinking step-by-step.
"""
RELEVANT_EXTRACTION_PROMPT = """You are a highly efficient information extraction module. Your sole purpose is to extract the single most relevant piece of information from the provided `Searched Web Pages` that directly answers the `Current Search Query`.
**Your Task:**
1. Read the `Current Search Query` to understand what specific information is needed.
2. Scan the `Searched Web Pages` to find the answer to that query.
3. Extract only the essential, factual information that answers the query. Be concise.
**Context (For Your Information Only):**
The `Previous Reasoning Steps` are provided to give you context on the overall goal, but your primary focus MUST be on answering the `Current Search Query`. Do not use information from the previous steps in your output.
**Output Format:**
Your response must follow one of two formats precisely.
1. **If a direct and relevant answer is found:**
- Start your response immediately with `Final Information`.
- Provide only the extracted fact(s). Do not add any extra conversational text.
*Example:*
`Current Search Query`: Where is Martin Campbell from?
`Searched Web Pages`: [Long article snippet about Martin Campbell's career, which includes the sentence "Martin Campbell (born 24 October 1943) is a New Zealand film and television director..."]
*Your Output:*
Final Information
Martin Campbell is a New Zealand film and television director.
2. **If no relevant answer that directly addresses the query is found in the web pages:**
- Start your response immediately with `Final Information`.
- Write the exact phrase: `No helpful information found.`
---
**BEGIN TASK**
**Inputs:**
- **Previous Reasoning Steps:**
{prev_reasoning}
- **Current Search Query:**
{search_query}
- **Searched Web Pages:**
{document}
"""

View File

@ -16,23 +16,21 @@
import logging
import os
import sys
import time
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from quart import Blueprint, Quart, request, g, current_app, session, jsonify
from quart import Blueprint, Quart, request, g, current_app, session
from flasgger import Swagger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
from common.constants import StatusEnum, RetCode
from common.constants import StatusEnum
from api.db.db_models import close_connection, APIToken
from api.db.services import UserService
from api.utils.json_encode import CustomJSONEncoder
from api.utils import commands
from quart_auth import Unauthorized as QuartAuthUnauthorized
from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
from quart_schema import QuartSchema
from quart_auth import Unauthorized
from common import settings
from api.utils.api_utils import server_error_response, get_json_result
from api.utils.api_utils import server_error_response
from api.constants import API_VERSION
from common.misc_utils import get_uuid
@ -40,27 +38,41 @@ settings.init_settings()
__all__ = ["app"]
UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
def _unauthorized_message(error):
if error is None:
return UNAUTHORIZED_MESSAGE
try:
msg = repr(error)
except Exception:
return UNAUTHORIZED_MESSAGE
if msg == UNAUTHORIZED_MESSAGE:
return msg
if "Unauthorized" in msg and "401" in msg:
return msg
return UNAUTHORIZED_MESSAGE
app = Quart(__name__)
app = cors(app, allow_origin="*")
# openapi supported
QuartSchema(app)
# Add this at the beginning of your file to configure Swagger UI
swagger_config = {
"headers": [],
"specs": [
{
"endpoint": "apispec",
"route": "/apispec.json",
"rule_filter": lambda rule: True, # Include all endpoints
"model_filter": lambda tag: True, # Include all models
}
],
"static_url_path": "/flasgger_static",
"swagger_ui": True,
"specs_route": "/apidocs/",
}
swagger = Swagger(
app,
config=swagger_config,
template={
"swagger": "2.0",
"info": {
"title": "RAGFlow API",
"description": "",
"version": "1.0.0",
},
"securityDefinitions": {
"ApiKeyAuth": {"type": "apiKey", "name": "Authorization", "in": "header"}
},
},
)
app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
@ -91,7 +103,6 @@ from werkzeug.local import LocalProxy
T = TypeVar("T")
P = ParamSpec("P")
def _load_user():
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = request.headers.get("Authorization")
@ -114,28 +125,18 @@ def _load_user():
user = UserService.query(
access_token=access_token, status=StatusEnum.VALID.value
)
if not user and len(authorization.split()) == 2:
objs = APIToken.query(token=authorization.split()[1])
if objs:
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
g.user = user[0]
return user[0]
except Exception as e_auth:
logging.warning(f"load_user got exception {e_auth}")
try:
authorization = request.headers.get("Authorization")
if len(authorization.split()) == 2:
objs = APIToken.query(token=authorization.split()[1])
if objs:
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
if user:
if not user[0].access_token or not user[0].access_token.strip():
logging.warning(f"User {user[0].email} has empty access_token in database")
return None
g.user = user[0]
return user[0]
except Exception as e_api_token:
logging.warning(f"load_user got exception {e_api_token}")
except Exception as e:
logging.warning(f"load_user got exception {e}")
current_user = LocalProxy(_load_user)
@ -163,18 +164,10 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
user = current_user
if timing_enabled:
logging.info(
"api_timing login_required auth_ms=%.2f path=%s",
(time.perf_counter() - t_start) * 1000,
request.path,
)
if not user: # or not session.get("_user_id"):
raise QuartAuthUnauthorized()
return await current_app.ensure_async(func)(*args, **kwargs)
if not current_user:# or not session.get("_user_id"):
raise Unauthorized()
else:
return await current_app.ensure_async(func)(*args, **kwargs)
return wrapper
@ -235,7 +228,6 @@ def logout_user():
return True
def search_pages_path(page_path):
app_path_list = [
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
@ -282,36 +274,6 @@ client_urls_prefix = [
]
@app.errorhandler(404)
async def not_found(error):
logging.error(f"The requested URL {request.path} was not found")
message = f"Not Found: {request.path}"
response = {
"code": RetCode.NOT_FOUND,
"message": message,
"data": None,
"error": "Not Found",
}
return jsonify(response), RetCode.NOT_FOUND
@app.errorhandler(401)
async def unauthorized(error):
logging.warning("Unauthorized request")
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
@app.errorhandler(QuartAuthUnauthorized)
async def unauthorized_quart_auth(error):
logging.warning("Unauthorized request (quart_auth)")
return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
@app.errorhandler(WerkzeugUnauthorized)
async def unauthorized_werkzeug(error):
logging.warning("Unauthorized request (werkzeug)")
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
@app.teardown_request
def _db_close(exception):
if exception:

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import asyncio
import inspect
import json
import logging
@ -29,14 +29,9 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import (
get_json_result,
server_error_response,
validate_request,
get_data_error_result,
get_request_json,
)
from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
get_request_json
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@ -47,7 +42,6 @@ from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from api.apps import login_required, current_user
from api.db.services.canvas_service import completion as agent_completion
@manager.route('/templates', methods=['GET']) # noqa: F821
@ -138,12 +132,12 @@ async def run():
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@ -153,13 +147,13 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
try:
canvas = Canvas(cvs.dsl, current_user.id, canvas_id=cvs.id)
canvas = Canvas(cvs.dsl, current_user.id)
except Exception as e:
return server_error_response(e)
@ -186,50 +180,6 @@ async def run():
return resp
@manager.route("/<canvas_id>/completion", methods=["POST"]) # noqa: F821
@login_required
async def exp_agent_completion(canvas_id):
tenant_id = current_user.id
req = await get_request_json()
return_trace = bool(req.get("return_trace", False))
async def generate():
trace_items = []
async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req):
if isinstance(answer, str):
try:
ans = json.loads(answer[5:]) # remove "data:"
except Exception:
continue
event = ans.get("event")
if event == "node_finished":
if return_trace:
data = ans.get("data", {})
trace_items.append(
{
"component_id": data.get("component_id"),
"trace": [copy.deepcopy(data)],
}
)
ans.setdefault("data", {})["trace"] = trace_items
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
yield answer
if event not in ["message", "message_end"]:
continue
yield answer
yield "data:[DONE]\n\n"
resp = Response(generate(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
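The route above streams Server-Sent-Events style lines of the form `data:<json>` followed by a final `data:[DONE]`, and itself re-parses answers with `json.loads(answer[5:])`. A client-side sketch of handling such lines; `parse_sse_line` is a hypothetical helper, not part of the project:

```python
import json

def parse_sse_line(line: str):
    """Return the decoded payload of one 'data:' line, or None for [DONE]/non-JSON lines."""
    if not line.startswith("data:"):
        return None
    payload = line[5:].strip()
    if payload == "[DONE]":
        return None
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return None

print(parse_sse_line('data:{"event": "node_finished", "data": {"component_id": "begin"}}'))
print(parse_sse_line("data:[DONE]"))  # None
```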
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
@ -282,7 +232,7 @@ async def reset():
if not e:
return get_data_error_result(message="canvas not found.")
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
@ -299,14 +249,11 @@ async def upload(canvas_id):
user_id = cvs["user_id"]
files = await request.files
file_objs = files.getlist("file") if files and files.get("file") else []
file = files['file'] if files and files.get("file") else None
try:
if len(file_objs) == 1:
return get_json_result(data=FileService.upload_info(user_id, file_objs[0], request.args.get("url")))
results = [FileService.upload_info(user_id, f) for f in file_objs]
return get_json_result(data=results)
return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url")))
except Exception as e:
return server_error_response(e)
return server_error_response(e)
@manager.route('/input_form', methods=['GET']) # noqa: F821
@ -323,7 +270,7 @@ def input_form():
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
return get_json_result(data=canvas.get_component_input_form(cpn_id))
except Exception as e:
return server_error_response(e)
@ -340,7 +287,7 @@ async def debug():
code=RetCode.OPERATING_ERROR)
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
canvas.message_id = get_uuid()
component = canvas.get_component(req["component_id"])["obj"]
@ -375,9 +322,6 @@ async def test_db_connect():
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
elif req["db_type"] == "oceanbase":
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"], charset="utf8mb4")
elif req["db_type"] == 'postgres':
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
@ -578,70 +522,24 @@ def sessions(canvas_id):
from_date = request.args.get("from_date")
to_date = request.args.get("to_date")
orderby = request.args.get("orderby", "update_time")
exp_user_id = request.args.get("exp_user_id")
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
if exp_user_id:
sess = API4ConversationService.get_names(canvas_id, exp_user_id)
return get_json_result(data={"total": len(sess), "sessions": sess})
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id)
None, user_id, include_dsl, keywords, from_date, to_date)
try:
return get_json_result(data={"total": total, "sessions": sess})
except Exception as e:
return server_error_response(e)
@manager.route('/<canvas_id>/sessions', methods=['PUT']) # noqa: F821
@login_required
async def set_session(canvas_id):
req = await get_request_json()
tenant_id = current_user.id
e, cvs = UserCanvasService.get_by_id(canvas_id)
assert e, "Agent not found."
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
session_id=get_uuid()
canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id)
canvas.reset()
conv = {
"id": session_id,
"name": req.get("name", ""),
"dialog_id": cvs.id,
"user_id": tenant_id,
"exp_user_id": tenant_id,
"message": [],
"source": "agent",
"dsl": cvs.dsl,
"reference": []
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
@manager.route('/<canvas_id>/sessions/<session_id>', methods=['GET']) # noqa: F821
@login_required
def get_session(canvas_id, session_id):
tenant_id = current_user.id
if not UserCanvasService.accessible(canvas_id, tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
conv = API4ConversationService.get_by_id(session_id)
return get_json_result(data=conv.to_dict())
@manager.route('/prompts', methods=['GET']) # noqa: F821
@login_required
def prompts():
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
return get_json_result(data={
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
"plan_generation": NEXT_STEP,

View File

@ -13,29 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import asyncio
import datetime
import json
import logging
import re
import base64
import xxhash
from quart import request
from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
get_request_json,
)
from common.misc_utils import thread_pool_exec
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -45,6 +38,7 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
from common import settings
from api.apps import login_required, current_user
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id")
@ -67,7 +61,7 @@ async def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@ -132,15 +126,10 @@ def get():
@validate_request("doc_id", "chunk_id", "content_with_weight")
async def set():
req = await get_request_json()
content_with_weight = req["content_with_weight"]
if not isinstance(content_with_weight, (str, bytes)):
raise TypeError("expected string or bytes-like object")
if isinstance(content_with_weight, bytes):
content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
d = {
"id": req["chunk_id"],
"content_with_weight": content_with_weight}
d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_kwd" in req:
if not isinstance(req["important_kwd"], list):
@ -182,21 +171,20 @@ async def set():
_d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
_d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
# update image
image_base64 = req.get("image_base64", None)
img_id = req.get("img_id", "")
if image_base64 and img_id and "-" in img_id:
bkt, name = img_id.split("-", 1)
if image_base64:
bkt, name = req.get("img_id", "-").split("-")
image_binary = base64.b64decode(image_base64)
settings.STORAGE_IMPL.put(bkt, name, image_binary)
return get_json_result(data=True)
return await thread_pool_exec(_set_sync)
return await asyncio.to_thread(_set_sync)
except Exception as e:
return server_error_response(e)
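One behavioural difference between the two variants in this hunk: splitting `img_id` with `maxsplit=1` keeps working when the object key itself contains `-`, while a bare `split("-")` raises on extra separators. A quick illustration:

```python
img_id = "kb123-doc456-chunk789.png"

bkt, name = img_id.split("-", 1)
print(bkt, name)  # kb123 doc456-chunk789.png

try:
    bkt, name = img_id.split("-")
except ValueError as e:
    print(e)  # too many values to unpack (expected 2)
```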
@ -219,7 +207,7 @@ async def switch():
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
return await thread_pool_exec(_switch_sync)
return await asyncio.to_thread(_switch_sync)
except Exception as e:
return server_error_response(e)
@ -234,34 +222,19 @@ async def rm():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
try:
deleted_count = settings.docStoreConn.delete(condition,
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id)
except Exception:
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
if isinstance(deleted_chunk_ids, list):
unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
has_ids = len(unique_chunk_ids) > 0
else:
unique_chunk_ids = [deleted_chunk_ids]
has_ids = deleted_chunk_ids not in (None, "")
if has_ids and deleted_count == 0:
return get_data_error_result(message="Index updating failure")
if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id)
chunk_number = deleted_count
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
return await thread_pool_exec(_rm_sync)
return await asyncio.to_thread(_rm_sync)
except Exception as e:
return server_error_response(e)
@ -271,7 +244,6 @@ async def rm():
@validate_request("doc_id", "content_with_weight")
async def create():
req = await get_request_json()
req_id = request.headers.get("X-Request-ID")
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
@ -288,23 +260,14 @@ async def create():
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
if "tag_feas" in req:
d["tag_feas"] = req["tag_feas"]
try:
def _log_response(resp, code, message):
logging.info(
"chunk_create response req_id=%s status=%s code=%s message=%s",
req_id,
getattr(resp, "status_code", None),
code,
message,
)
def _create_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
resp = get_data_error_result(message="Document not found!")
_log_response(resp, RetCode.DATA_ERROR, "Document not found!")
return resp
return get_data_error_result(message="Document not found!")
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
@ -312,15 +275,11 @@ async def create():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
resp = get_data_error_result(message="Tenant not found!")
_log_response(resp, RetCode.DATA_ERROR, "Tenant not found!")
return resp
return get_data_error_result(message="Tenant not found!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
resp = get_data_error_result(message="Knowledgebase not found!")
_log_response(resp, RetCode.DATA_ERROR, "Knowledgebase not found!")
return resp
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
@ -334,13 +293,10 @@ async def create():
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
resp = get_json_result(data={"chunk_id": chunck_id})
_log_response(resp, RetCode.SUCCESS, "success")
return resp
return get_json_result(data={"chunk_id": chunck_id})
return await thread_pool_exec(_create_sync)
return await asyncio.to_thread(_create_sync)
except Exception as e:
logging.info("chunk_create exception req_id=%s error=%r", req_id, e)
return server_error_response(e)
@ -382,7 +338,7 @@ async def retrieval_test():
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
metas = DocumentService.get_meta_by_kbs(kb_ids)
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
tenants = UserTenantService.query(user_id=user_id)
@ -416,23 +372,16 @@ async def retrieval_test():
_question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb])
ranks = await settings.retriever.retrieval(
_question,
embd_mdl,
tenant_ids,
kb_ids,
page,
size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
doc_ids=local_doc_ids,
top=top,
rerank_mdl=rerank_mdl,
rank_feature=labels
)
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
top,
local_doc_ids, rerank_mdl=rerank_mdl,
highlight=req.get("highlight", False),
rank_feature=labels
)
if use_kg:
ck = await settings.kg_retriever.retrieval(_question,
ck = settings.kg_retriever.retrieval(_question,
tenant_ids,
kb_ids,
embd_mdl,
@ -458,7 +407,7 @@ async def retrieval_test():
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
async def knowledge_graph():
def knowledge_graph():
doc_id = request.args["doc_id"]
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
@ -466,7 +415,7 @@ async def knowledge_graph():
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]

View File

@ -52,7 +52,7 @@ async def set_connector():
"source": req["source"],
"input_type": InputType.POLL,
"config": req["config"],
"refresh_freq": int(req.get("refresh_freq", 5)),
"refresh_freq": int(req.get("refresh_freq", 30)),
"prune_freq": int(req.get("prune_freq", 720)),
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
"status": TaskStatus.SCHEDULE,

View File

@ -25,7 +25,6 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.apps import login_required, current_user
import logging
@manager.route('/set', methods=['POST']) # noqa: F821
@ -43,19 +42,13 @@ async def set_dialog():
if len(name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
name = name.strip()
if is_create:
# only for chat creating
existing_names = {
d.name.casefold()
for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
if d.name
}
if name.casefold() in existing_names:
def _name_exists(name: str, **_kwargs) -> bool:
return name.casefold() in existing_names
name = duplicate_name(_name_exists, name=name)
if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
name = name.strip()
name = duplicate_name(
DialogService.query,
name=name,
tenant_id=current_user.id,
status=StatusEnum.VALID.value)
description = req.get("description", "A helpful dialog")
icon = req.get("icon", "")
@ -70,30 +63,16 @@ async def set_dialog():
meta_data_filter = req.get("meta_data_filter", {})
prompt_config = req["prompt_config"]
# Set default parameters for datasets with knowledge retrieval
# All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
kb_ids = req.get("kb_ids", [])
parameters = prompt_config.get("parameters")
logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
# Check if parameters is missing, None, or empty list
if kb_ids and not parameters:
# Check if system prompt uses {knowledge} placeholder
if "{knowledge}" in prompt_config.get("system", ""):
# Set default parameters for any dataset with knowledge placeholder
prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")
if not is_create:
# only for chat updating
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
for p in prompt_config.get("parameters", []):
if p["optional"]:
continue
if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
return get_data_error_result(
message="Parameter '{}' is not used".format(p["key"]))
for p in prompt_config["parameters"]:
if p["optional"]:
continue
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
return get_data_error_result(
message="Parameter '{}' is not used".format(p["key"]))
try:
e, tenant = TenantService.get_by_id(current_user.id)
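Both variants of the parameter loop in the hunk above enforce the same rule: every non-optional entry in prompt_config["parameters"] must appear as a {key} placeholder in the system prompt. A standalone sketch of that check:

```python
def find_unused_parameters(prompt_config: dict) -> list[str]:
    """Return required parameter keys that never appear as {key} in the system prompt."""
    system = prompt_config.get("system", "")
    missing = []
    for p in prompt_config.get("parameters", []):
        if p.get("optional"):
            continue  # optional parameters may be absent from the prompt
        if "{%s}" % p["key"] not in system:
            missing.append(p["key"])
    return missing


# Example: "knowledge" is required but the prompt only references {question}.
cfg = {
    "system": "Answer using {question}.",
    "parameters": [{"key": "knowledge", "optional": False}],
}
assert find_unused_parameters(cfg) == ["knowledge"]
```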

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
import asyncio
import json
import os.path
import pathlib
@ -26,20 +27,18 @@ from api.db import VALID_FILE_TYPES, FileType
from api.db.db_models import Task
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.doc_metadata_service import DocMetadataService
from common.metadata_utils import meta_filter, convert_conditions, turn2jsonschema
from common.metadata_utils import meta_filter, convert_conditions
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, cancel_all_task_of
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid, thread_pool_exec
from common.misc_utils import get_uuid
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
get_request_json,
validate_request, get_request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory
@ -63,21 +62,10 @@ async def upload():
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
file_objs = files.getlist("file")
def _close_file_objs(objs):
for obj in objs:
try:
obj.close()
except Exception:
try:
obj.stream.close()
except Exception:
pass
for file_obj in file_objs:
if file_obj.filename == "":
_close_file_objs(file_objs)
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
_close_file_objs(file_objs)
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
@ -86,9 +74,8 @@ async def upload():
if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
if err:
files = [f[0] for f in files] if files else []
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
if not files:
@ -227,7 +214,6 @@ async def list_docs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
@ -248,10 +234,6 @@ async def list_docs():
req = await get_request_json()
return_empty_metadata = req.get("return_empty_metadata", False)
if isinstance(return_empty_metadata, str):
return_empty_metadata = return_empty_metadata.lower() == "true"
run_status = req.get("run_status", [])
if run_status:
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
@ -266,23 +248,16 @@ async def list_docs():
suffix = req.get("suffix", [])
metadata_condition = req.get("metadata_condition", {}) or {}
if metadata_condition and not isinstance(metadata_condition, dict):
return get_data_error_result(message="metadata_condition must be an object.")
metadata = req.get("metadata", {}) or {}
if isinstance(metadata, dict) and metadata.get("empty_metadata"):
return_empty_metadata = True
metadata = {k: v for k, v in metadata.items() if k != "empty_metadata"}
if return_empty_metadata:
metadata_condition = {}
metadata = {}
else:
if metadata_condition and not isinstance(metadata_condition, dict):
return get_data_error_result(message="metadata_condition must be an object.")
if metadata and not isinstance(metadata, dict):
return get_data_error_result(message="metadata must be an object.")
if metadata and not isinstance(metadata, dict):
return get_data_error_result(message="metadata must be an object.")
doc_ids_filter = None
metas = None
if metadata_condition or metadata:
metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id])
metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
if metadata_condition:
doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
@ -320,19 +295,7 @@ async def list_docs():
doc_ids_filter = list(doc_ids_filter)
try:
docs, tol = DocumentService.get_by_kb_id(
kb_id,
page_number,
items_per_page,
orderby,
desc,
keywords,
run_status,
types,
suffix,
doc_ids_filter,
return_empty_metadata=return_empty_metadata,
)
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids_filter)
if create_time_from or create_time_to:
filtered_docs = []
@ -347,8 +310,6 @@ async def list_docs():
doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
if doc_item.get("source_type"):
doc_item["source_type"] = doc_item["source_type"].split("/")[0]
if doc_item["parser_config"].get("metadata"):
doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
return get_json_result(data={"total": tol, "docs": docs})
except Exception as e:
@ -402,11 +363,7 @@ async def doc_infos():
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
docs = DocumentService.get_by_ids(doc_ids)
docs_list = list(docs.dicts())
# Add meta_fields for each document
for doc in docs_list:
doc["meta_fields"] = DocMetadataService.get_document_metadata(doc["id"])
return get_json_result(data=docs_list)
return get_json_result(data=list(docs.dicts()))
@manager.route("/metadata/summary", methods=["POST"]) # noqa: F821
@ -414,7 +371,6 @@ async def doc_infos():
async def metadata_summary():
req = await get_request_json()
kb_id = req.get("kb_id")
doc_ids = req.get("doc_ids")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -426,7 +382,7 @@ async def metadata_summary():
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
try:
summary = DocMetadataService.get_metadata_summary(kb_id, doc_ids)
summary = DocumentService.get_metadata_summary(kb_id)
return get_json_result(data={"summary": summary})
except Exception as e:
return server_error_response(e)
@ -434,20 +390,36 @@ async def metadata_summary():
@manager.route("/metadata/update", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids")
async def metadata_update():
req = await get_request_json()
kb_id = req.get("kb_id")
document_ids = req.get("doc_ids")
updates = req.get("updates", []) or []
deletes = req.get("deletes", []) or []
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
selector = req.get("selector", {}) or {}
updates = req.get("updates", []) or []
deletes = req.get("deletes", []) or []
if not isinstance(selector, dict):
return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR)
if not isinstance(updates, list) or not isinstance(deletes, list):
return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
metadata_condition = selector.get("metadata_condition", {}) or {}
if metadata_condition and not isinstance(metadata_condition, dict):
return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR)
document_ids = selector.get("document_ids", []) or []
if document_ids and not isinstance(document_ids, list):
return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR)
for upd in updates:
if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
@ -455,8 +427,24 @@ async def metadata_update():
if not isinstance(d, dict) or not d.get("key"):
return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
updated = DocMetadataService.batch_update_metadata(kb_id, document_ids, updates, deletes)
return get_json_result(data={"updated": updated, "matched_docs": len(document_ids)})
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id])
target_doc_ids = set(kb_doc_ids)
if document_ids:
invalid_ids = set(document_ids) - set(kb_doc_ids)
if invalid_ids:
return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR)
target_doc_ids = set(document_ids)
if metadata_condition:
metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
target_doc_ids = target_doc_ids & filtered_ids
if metadata_condition.get("conditions") and not target_doc_ids:
return get_json_result(data={"updated": 0, "matched_docs": 0})
target_doc_ids = list(target_doc_ids)
updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes)
return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
@ -510,61 +498,31 @@ async def change_status():
return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=RetCode.ARGUMENT_ERROR)
result = {}
has_error = False
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
result[doc_id] = {"error": "No authorization."}
has_error = True
continue
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
result[doc_id] = {"error": "No authorization."}
has_error = True
continue
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
result[doc_id] = {"error": "Can't find this dataset!"}
has_error = True
continue
current_status = str(doc.status)
if current_status == status:
result[doc_id] = {"status": status}
continue
if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
result[doc_id] = {"error": "Database error (Document update)!"}
has_error = True
continue
status_int = int(status)
if getattr(doc, "chunk_num", 0) > 0:
try:
ok = settings.docStoreConn.update(
{"doc_id": doc_id},
{"available_int": status_int},
search.index_name(kb.tenant_id),
doc.kb_id,
)
except Exception as exc:
msg = str(exc)
if "3022" in msg:
result[doc_id] = {"error": "Document store table missing."}
else:
result[doc_id] = {"error": f"Document store update failed: {msg}"}
has_error = True
continue
if not ok:
result[doc_id] = {"error": "Database error (docStore update)!"}
has_error = True
continue
if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
result[doc_id] = {"error": "Database error (docStore update)!"}
result[doc_id] = {"status": status}
except Exception as e:
result[doc_id] = {"error": f"Internal server error: {str(e)}"}
has_error = True
if has_error:
return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR)
return get_json_result(data=result)
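change_status above processes each document independently, records either a status or an error per id, and reports partial failure instead of aborting on the first problem. A reduced sketch of that accumulation pattern (the action callable stands in for the authorization, lookup, and update steps):

```python
def apply_to_each(doc_ids, action):
    """Run action(doc_id) per document; collect per-id outcomes and a partial-failure flag.

    action returns a result dict on success and raises on failure.
    """
    result, has_error = {}, False
    for doc_id in doc_ids:
        try:
            result[doc_id] = action(doc_id)
        except Exception as e:  # keep going; report the failure for this id only
            result[doc_id] = {"error": str(e)}
            has_error = True
    return result, has_error


def flip(doc_id):
    if doc_id == "d2":
        raise RuntimeError("No authorization.")
    return {"status": "1"}


outcome, partial = apply_to_each(["d1", "d2"], flip)
assert partial and outcome["d2"] == {"error": "No authorization."}
```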
@ -581,7 +539,7 @@ async def rm():
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@ -594,11 +552,10 @@ async def rm():
@validate_request("doc_ids", "run")
async def run():
req = await get_request_json()
uid = current_user.id
try:
def _run_sync():
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, uid):
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
kb_table_num_map = {}
@ -631,20 +588,12 @@ async def run():
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
if req.get("apply_kb"):
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
raise LookupError("Can't find this dataset!")
doc.parser_config["llm_id"] = kb.parser_config.get("llm_id")
doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False)
doc.parser_config["metadata"] = kb.parser_config.get("metadata", {})
DocumentService.update_parser_config(doc.id, doc.parser_config)
doc_dict = doc.to_dict()
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
return get_json_result(data=True)
return await thread_pool_exec(_run_sync)
return await asyncio.to_thread(_run_sync)
except Exception as e:
return server_error_response(e)
@ -654,10 +603,9 @@ async def run():
@validate_request("doc_id", "name")
async def rename():
req = await get_request_json()
uid = current_user.id
try:
def _rename_sync():
if not DocumentService.accessible(req["doc_id"], uid):
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
e, doc = DocumentService.get_by_id(req["doc_id"])
@ -696,7 +644,7 @@ async def rename():
)
return get_json_result(data=True)
return await thread_pool_exec(_rename_sync)
return await asyncio.to_thread(_rename_sync)
except Exception as e:
return server_error_response(e)
@ -711,7 +659,7 @@ async def get(doc_id):
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
response = await make_response(data)
ext = re.search(r"\.([^.]+)$", doc.name.lower())
@ -733,7 +681,7 @@ async def get(doc_id):
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@ -768,7 +716,6 @@ async def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
DocumentService.delete_chunk_images(doc, tenant_id)
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None
@ -806,7 +753,7 @@ async def get_image(image_id):
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
response = await make_response(data)
response.headers.set("Content-Type", "image/JPEG")
return response
@ -914,7 +861,7 @@ async def set_meta():
if not e:
return get_data_error_result(message="Document not found!")
if not DocMetadataService.update_document_metadata(req["doc_id"], meta):
if not DocumentService.update_by_id(req["doc_id"], {"meta_fields": meta}):
return get_data_error_result(message="Database error (meta updates)!")
return get_json_result(data=True)

View File

@ -14,6 +14,7 @@
# limitations under the License
#
import logging
import asyncio
import os
import pathlib
import re
@ -24,7 +25,7 @@ from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from common.misc_utils import get_uuid, thread_pool_exec
from common.misc_utils import get_uuid
from common.constants import RetCode, FileSource
from api.db import FileType
from api.db.services import duplicate_name
@ -34,6 +35,7 @@ from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
@ -63,7 +65,7 @@ async def upload():
async def _handle_single_file(file_obj):
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
return get_data_error_result( message="Exceed the maximum file number of a free user!")
# split file name path
@ -75,35 +77,35 @@ async def upload():
file_len = len(file_obj_names)
# get folder
file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)
# create folder
if file_len != len_id_list:
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)
# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
location += "_"
blob = await thread_pool_exec(file_obj.read)
filename = await thread_pool_exec(
blob = await asyncio.to_thread(file_obj.read)
filename = await asyncio.to_thread(
duplicate_name,
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
file_data = {
"id": get_uuid(),
"parent_id": last_folder.id,
@ -114,7 +116,7 @@ async def upload():
"location": location,
"size": len(blob),
}
inserted = await thread_pool_exec(FileService.insert, file_data)
inserted = await asyncio.to_thread(FileService.insert, file_data)
return inserted.to_json()
for file_obj in file_objs:
@ -247,7 +249,6 @@ def get_all_parent_folders():
async def rm():
req = await get_request_json()
file_ids = req["file_ids"]
uid = current_user.id
try:
def _delete_single_file(file):
@ -286,21 +287,21 @@ async def rm():
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if not check_file_team_permission(file, uid):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
if file.type == FileType.FOLDER.value:
_delete_folder_recursive(file, uid)
_delete_folder_recursive(file, current_user.id)
continue
_delete_single_file(file)
return get_json_result(data=True)
return await thread_pool_exec(_rm_sync)
return await asyncio.to_thread(_rm_sync)
except Exception as e:
return server_error_response(e)
@ -356,10 +357,10 @@ async def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@ -459,7 +460,7 @@ async def move():
_move_entry_recursive(file, dest_folder)
return get_json_result(data=True)
return await thread_pool_exec(_move_sync)
return await asyncio.to_thread(_move_sync)
except Exception as e:
return server_error_response(e)
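This file and the document routes above repeatedly wrap blocking ORM and storage work in a synchronous inner function and hand it to a worker thread, keeping the event loop free while the work runs. A minimal sketch of that pattern using only the standard library (the handler and storage call are illustrative stand-ins, not the project's own names):

```python
import asyncio
import time


def _load_blob_sync(bucket: str, name: str) -> bytes:
    """Stand-in for a blocking storage read (e.g. an S3/MinIO client call)."""
    time.sleep(0.1)  # simulate I/O that would otherwise block the event loop
    return f"{bucket}/{name}".encode()


async def get_file(bucket: str, name: str) -> bytes:
    # asyncio.to_thread runs the sync callable in the default thread pool
    # and resumes this coroutine when it finishes.
    return await asyncio.to_thread(_load_blob_sync, bucket, name)


if __name__ == "__main__":
    print(asyncio.run(get_file("kb-bucket", "report.pdf")))
```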

View File

@ -17,29 +17,21 @@ import json
import logging
import random
import re
import asyncio
from common.metadata_utils import turn2jsonschema
from quart import request
import numpy as np
from api.db.services.connector_service import Connector2KbService
from api.db.services.llm_service import LLMBundle
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import (
get_error_data_result,
server_error_response,
get_data_error_result,
validate_request,
not_allowed_parameters,
get_request_json,
)
from common.misc_utils import thread_pool_exec
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
get_request_json
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@ -52,6 +44,7 @@ from common import settings
from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@validate_request("name")
@ -89,20 +82,6 @@ async def update():
return get_data_error_result(
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
if settings.DOC_ENGINE_INFINITY:
parser_id = req.get("parser_id")
if isinstance(parser_id, str) and parser_id.lower() == "tag":
return get_json_result(
code=RetCode.OPERATING_ERROR,
message="The chunking method Tag has not been supported by Infinity yet.",
data=False,
)
if "pagerank" in req and req["pagerank"] > 0:
return get_json_result(
code=RetCode.DATA_ERROR,
message="'pagerank' can only be set when doc_engine is elasticsearch",
data=False,
)
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
@ -151,7 +130,7 @@ async def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
await thread_pool_exec(
await asyncio.to_thread(
settings.docStoreConn.update,
{"kb_id": kb.id},
{PAGERANK_FLD: req["pagerank"]},
@ -160,7 +139,7 @@ async def update():
)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
await thread_pool_exec(
await asyncio.to_thread(
settings.docStoreConn.update,
{"exists": PAGERANK_FLD},
{"remove": PAGERANK_FLD},
@ -195,7 +174,6 @@ async def update_metadata_setting():
message="Database error (Knowledgebase rename)!")
kb = kb.to_dict()
kb["parser_config"]["metadata"] = req["metadata"]
kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True)
KnowledgebaseService.update_by_id(kb["id"], kb)
return get_json_result(data=kb)
@ -220,8 +198,6 @@ def detail():
message="Can't find this dataset!")
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
kb["connectors"] = Connector2KbService.list_connectors(kb_id)
if kb["parser_config"].get("metadata"):
kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"])
for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
if finish_at := kb.get(key):
@ -273,8 +249,7 @@ async def list_kbs():
@validate_request("kb_id")
async def rm():
req = await get_request_json()
uid = current_user.id
if not KnowledgebaseService.accessible4deletion(req["kb_id"], uid):
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
message='No authorization.',
@ -282,7 +257,7 @@ async def rm():
)
try:
kbs = KnowledgebaseService.query(
created_by=uid, id=req["kb_id"])
created_by=current_user.id, id=req["kb_id"])
if not kbs:
return get_json_result(
data=False, message='Only owner of dataset authorized for this operation.',
@ -305,24 +280,17 @@ async def rm():
File.name == kbs[0].name,
]
)
# Delete the table BEFORE deleting the database record
for kb in kbs:
try:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
logging.info(f"Dropped index for dataset {kb.id}")
except Exception as e:
logging.error(f"Failed to drop index for dataset {kb.id}: {e}")
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
return await thread_pool_exec(_rm_sync)
return await asyncio.to_thread(_rm_sync)
except Exception as e:
return server_error_response(e)
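One side of the hunk above drops each dataset's search index before the database record is deleted and isolates failures per dataset, so a broken index never blocks removal of the row. A minimal sketch of that ordering, with the doc-store connection passed in (method names follow the ones shown in the diff):

```python
import logging


def remove_dataset(kb, doc_store, index_name, delete_kb_row):
    """Drop search-engine data first, then the database record.

    kb: object with .id and .tenant_id; doc_store: connection exposing
    delete()/delete_idx(); delete_kb_row: callable returning True on success.
    """
    try:
        # 1. Remove chunks and the index itself; tolerate doc-store failures.
        doc_store.delete({"kb_id": kb.id}, index_name(kb.tenant_id), kb.id)
        doc_store.delete_idx(index_name(kb.tenant_id), kb.id)
        logging.info("Dropped index for dataset %s", kb.id)
    except Exception as e:
        logging.error("Failed to drop index for dataset %s: %s", kb.id, e)

    # 2. Only now remove the database record.
    return delete_kb_row(kb.id)
```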
@ -404,7 +372,7 @@ async def rename_tags(kb_id):
@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
async def knowledge_graph(kb_id):
def knowledge_graph(kb_id):
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@ -420,7 +388,7 @@ async def knowledge_graph(kb_id):
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)
@ -468,7 +436,7 @@ def get_meta():
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
return get_json_result(data=DocMetadataService.get_flatted_meta_by_kbs(kb_ids))
return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
@manager.route("/basic_info", methods=["GET"]) # noqa: F821

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import json
import os
@ -65,17 +64,13 @@ async def set_api_key():
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
extra = {"provider": factory}
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
msg = ""
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
arr, tc = await asyncio.wait_for(
asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
timeout=timeout_seconds,
)
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
raise Exception("Fail")
embd_passed = True
@ -85,27 +80,17 @@ async def set_api_key():
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
try:
m, tc = await asyncio.wait_for(
mdl.async_chat(
None,
[{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9, "max_tokens": 50},
),
timeout=timeout_seconds,
)
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
if m.find("**ERROR**") >= 0:
raise Exception(m)
chat_passed = True
except Exception as e:
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
elif not rerank_passed and llm.model_type == LLMType.RERANK:
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
arr, tc = await asyncio.wait_for(
asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
timeout=timeout_seconds,
)
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
raise Exception("Fail")
rerank_passed = True
@ -116,9 +101,6 @@ async def set_api_key():
msg = ""
break
if req.get("verify", False):
return get_json_result(data={"message": msg, "success": len(msg.strip())==0})
if msg:
return get_data_error_result(message=msg)
@ -151,7 +133,6 @@ async def add_llm():
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
if factory not in [f.name for f in get_allowed_llm_factories()]:
return get_data_error_result(message=f"LLM factory {factory} is not allowed")
@ -165,6 +146,10 @@ async def add_llm():
# Assemble ark_api_key endpoint_id into api_key
api_key = apikey_json(["ark_api_key", "endpoint_id"])
elif factory == "Tencent Hunyuan":
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
return await set_api_key()
elif factory == "Tencent Cloud":
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
return await set_api_key()
@ -210,9 +195,6 @@ async def add_llm():
elif factory == "MinerU":
api_key = apikey_json(["api_key", "provider_order"])
elif factory == "PaddleOCR":
api_key = apikey_json(["api_key", "provider_order"])
llm = {
"tenant_id": current_user.id,
"llm_factory": factory,
@ -234,10 +216,7 @@ async def add_llm():
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
arr, tc = await asyncio.wait_for(
asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
timeout=timeout_seconds,
)
arr, tc = mdl.encode(["Test if the api key is available"])
if len(arr[0]) == 0:
raise Exception("Fail")
except Exception as e:
@ -251,14 +230,8 @@ async def add_llm():
**extra,
)
try:
m, tc = await asyncio.wait_for(
mdl.async_chat(
None,
[{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9},
),
timeout=timeout_seconds,
)
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
{"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
@ -268,10 +241,7 @@ async def add_llm():
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
try:
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
arr, tc = await asyncio.wait_for(
asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
timeout=timeout_seconds,
)
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
if len(arr) == 0:
raise Exception("Not known.")
except KeyError:
@ -284,10 +254,7 @@ async def add_llm():
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
image_data = test_image
m, tc = await asyncio.wait_for(
asyncio.to_thread(mdl.describe, image_data),
timeout=timeout_seconds,
)
m, tc = mdl.describe(image_data)
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:
@ -296,29 +263,20 @@ async def add_llm():
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
try:
def drain_tts():
for _ in mdl.tts("Hello~ RAGFlower!"):
pass
await asyncio.wait_for(
asyncio.to_thread(drain_tts),
timeout=timeout_seconds,
)
for resp in mdl.tts("Hello~ RAGFlower!"):
pass
except RuntimeError as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.OCR.value:
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
try:
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
ok, reason = await asyncio.wait_for(
asyncio.to_thread(mdl.check_available),
timeout=timeout_seconds,
)
ok, reason = mdl.check_available()
if not ok:
raise RuntimeError(reason or "Model not available")
except Exception as e:
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
case LLMType.SPEECH2TEXT.value:
case LLMType.SPEECH2TEXT:
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
try:
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
@ -328,9 +286,6 @@ async def add_llm():
case _:
raise RuntimeError(f"Unknown model type: {model_type}")
if req.get("verify", False):
return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
if msg:
return get_data_error_result(message=msg)
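One side of these hunks bounds every connectivity probe with asyncio.wait_for and a timeout read from LLM_TIMEOUT_SECONDS, so a slow provider cannot hang the request. A minimal sketch of that wrapper around a blocking SDK call, using only the standard library (probe_embedding stands in for a call like mdl.encode):

```python
import asyncio
import os
import time


def probe_embedding(text: str) -> list[float]:
    """Stand-in for a blocking SDK call such as mdl.encode([...])."""
    time.sleep(0.2)
    return [0.0] * 8


async def check_model(timeout_env: str = "LLM_TIMEOUT_SECONDS") -> str:
    timeout_seconds = int(os.environ.get(timeout_env, 10))
    try:
        # Offload the blocking call, then bound the whole probe by the timeout.
        arr = await asyncio.wait_for(
            asyncio.to_thread(probe_embedding, "Test if the api key is available"),
            timeout=timeout_seconds,
        )
        return "" if arr else "Fail"
    except asyncio.TimeoutError:
        return f"Model check timed out after {timeout_seconds}s"


if __name__ == "__main__":
    print(asyncio.run(check_model()) or "ok")
```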
@ -416,18 +371,17 @@ def my_llms():
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
async def list_app():
def list_app():
self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = []
model_type = request.args.get("model_type")
tenant_id = current_user.id
try:
TenantLLMService.ensure_mineru_from_env(tenant_id)
objs = TenantLLMService.query(tenant_id=tenant_id)
TenantLLMService.ensure_mineru_from_env(current_user.id)
objs = TenantLLMService.query(tenant_id=current_user.id)
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == "Builtin" or (m.llm_name + "@" + m.fid) in status)]
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == 'Builtin' or (m.llm_name + "@" + m.fid) in status)]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from quart import Response, request
from api.apps import current_user, login_required
@ -21,11 +23,12 @@ from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
from common.misc_utils import get_uuid, thread_pool_exec
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
async def list_mcp() -> Response:
@ -105,7 +108,7 @@ async def create() -> Response:
return get_data_error_result(message="Tenant not found.")
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@ -157,7 +160,7 @@ async def update() -> Response:
req["id"] = mcp_id
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@ -241,7 +244,7 @@ async def import_multiple() -> Response:
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
if err_message:
results.append({"server": base_name, "success": False, "message": err_message})
continue
@ -321,7 +324,7 @@ async def list_tools() -> Response:
tool_call_sessions.append(tool_call_session)
try:
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
except Exception as e:
return get_data_error_result(message=f"MCP list tools error: {e}")
@ -338,7 +341,7 @@ async def list_tools() -> Response:
return server_error_response(e)
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@ -365,10 +368,10 @@ async def test_tool() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
# PERF: blocking call to close sessions — consider moving to background thread or task queue
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@ -422,12 +425,12 @@ async def test_mcp() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
try:
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
except Exception as e:
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
for tool in tools:
tool_dict = tool.model_dump()

View File

@ -14,8 +14,6 @@
# limitations under the License.
#
import logging
import os
import time
from quart import request
from api.apps import login_required, current_user
@ -23,70 +21,34 @@ from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.task_service import TaskService
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
not_allowed_parameters
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService
from memory.utils.prompt_util import PromptAssembler
from common.constants import MemoryType, RetCode, ForgettingPolicy
@manager.route("/memories", methods=["POST"]) # noqa: F821
@manager.route("", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
timing_enabled = os.getenv("RAGFLOW_API_TIMING")
t_start = time.perf_counter() if timing_enabled else None
req = await get_request_json()
t_parsed = time.perf_counter() if timing_enabled else None
# check name length
name = req["name"]
memory_name = name.strip()
if len(memory_name) == 0:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result("Memory name cannot be empty or whitespace.")
if len(memory_name) > MEMORY_NAME_LIMIT:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
# check memory_type valid
if not isinstance(req["memory_type"], list):
if timing_enabled:
logging.info(
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result("Memory type must be a list.")
memory_type = set(req["memory_type"])
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
if invalid_type:
if timing_enabled:
logging.info(
"api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
memory_type = list(memory_type)
try:
t_before_db = time.perf_counter() if timing_enabled else None
res, memory = MemoryService.create_memory(
tenant_id=current_user.id,
name=memory_name,
@ -94,15 +56,6 @@ async def create_memory():
embd_id=req["embd_id"],
llm_id=req["llm_id"]
)
if timing_enabled:
logging.info(
"api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
(t_parsed - t_start) * 1000,
(t_before_db - t_parsed) * 1000,
(time.perf_counter() - t_before_db) * 1000,
(time.perf_counter() - t_start) * 1000,
request.path,
)
if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@ -113,8 +66,9 @@ async def create_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
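One side of the create_memory hunks instruments the handler with time.perf_counter checkpoints that are only logged when RAGFLOW_API_TIMING is set. A minimal sketch of the same gating expressed as a context manager (the helper name is illustrative, not part of the codebase):

```python
import logging
import os
import time
from contextlib import contextmanager


@contextmanager
def api_timing(label: str):
    """Log elapsed milliseconds for a block, but only when RAGFLOW_API_TIMING is set."""
    enabled = bool(os.getenv("RAGFLOW_API_TIMING"))
    start = time.perf_counter() if enabled else None
    try:
        yield
    finally:
        if enabled:
            logging.info("api_timing %s total_ms=%.2f", label,
                         (time.perf_counter() - start) * 1000)


# Usage inside a handler body:
with api_timing("create_memory"):
    time.sleep(0.01)  # stands in for validation + MemoryService.create_memory(...)
```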
@manager.route("/memories/<memory_id>", methods=["PUT"]) # noqa: F821
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
@login_required
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
async def update_memory(memory_id):
req = await get_request_json()
update_dict = {}
@ -134,14 +88,6 @@ async def update_memory(memory_id):
update_dict["permissions"] = req["permissions"]
if req.get("llm_id"):
update_dict["llm_id"] = req["llm_id"]
if req.get("embd_id"):
update_dict["embd_id"] = req["embd_id"]
if req.get("memory_type"):
memory_type = set(req["memory_type"])
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
if invalid_type:
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
update_dict["memory_type"] = list(memory_type)
# check memory_size valid
if req.get("memory_size"):
if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT:
@ -177,15 +123,6 @@ async def update_memory(memory_id):
if not to_update:
return get_json_result(message=True, data=memory_dict)
# check memory empty when update embd_id, memory_type
memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
if not_allowed_update:
return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.")
if "memory_type" in to_update:
if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
# update old default prompt, assemble a new one
to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})
try:
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
@ -197,7 +134,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/memories/<memory_id>", methods=["DELETE"]) # noqa: F821
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
@ -205,15 +142,14 @@ async def delete_memory(memory_id):
return get_json_result(message=True, code=RetCode.NOT_FOUND)
try:
MemoryService.delete_memory(memory_id)
if MessageService.has_index(memory.tenant_id, memory_id):
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
return get_json_result(message=True)
except Exception as e:
logging.error(e)
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/memories", methods=["GET"]) # noqa: F821
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def list_memory():
args = request.args
@ -225,18 +161,13 @@ async def list_memory():
page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50))
# make filter dict
filter_dict: dict = {"storage_type": storage_type}
filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
if not tenant_ids:
# restrict to current user's tenants
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
else:
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
tenant_ids = tenant_ids[0].split(',')
filter_dict["tenant_id"] = tenant_ids
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
memory_types = memory_types[0].split(',')
filter_dict["memory_type"] = memory_types
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
@ -247,7 +178,7 @@ async def list_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/memories/<memory_id>/config", methods=["GET"]) # noqa: F821
@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_config(memory_id):
memory = MemoryService.get_with_owner_name_by_id(memory_id)
@ -256,13 +187,11 @@ async def get_memory_config(memory_id):
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@manager.route("/memories/<memory_id>", methods=["GET"]) # noqa: F821
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_detail(memory_id):
args = request.args
agent_ids = args.getlist("agent_id")
if len(agent_ids) == 1 and ',' in agent_ids[0]:
agent_ids = agent_ids[0].split(',')
keywords = args.get("keywords", "")
keywords = keywords.strip()
page = int(args.get("page", 1))
@ -273,19 +202,9 @@ async def get_memory_detail(memory_id):
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
extract_task_mapping = {}
if messages["message_list"]:
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
if task_list:
task_list.sort(key=lambda t: t["create_time"])  # ascending, so the newer task wins when more than one exists
for task in task_list:
# the 'digest' field carries the source_id when a task is created, so use 'digest' as key
extract_task_mapping.update({int(task["digest"]): task})
for message in messages["message_list"]:
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
message["task"] = extract_task_mapping.get(message["message_id"], {})
for extract_msg in message["extract"]:
extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)

View File

@ -20,35 +20,49 @@ from common.time_utils import current_timestamp, timestamp_to_date
from memory.services.messages import MessageService
from api.db.joint_services import memory_message_service
from api.db.joint_services.memory_message_service import query_message
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from common.constants import RetCode
@manager.route("/messages", methods=["POST"]) # noqa: F821
@manager.route("", methods=["POST"]) # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
req = await get_request_json()
memory_ids = req["memory_id"]
agent_id = req["agent_id"]
session_id = req["session_id"]
user_id = req["user_id"] if req.get("user_id") else ""
user_input = req["user_input"]
agent_response = req["agent_response"]
message_dict = {
"user_id": req.get("user_id"),
"agent_id": req["agent_id"],
"session_id": req["session_id"],
"user_input": req["user_input"],
"agent_response": req["agent_response"],
}
res = []
for memory_id in memory_ids:
success, msg = await memory_message_service.save_to_memory(
memory_id,
{
"user_id": user_id,
"agent_id": agent_id,
"session_id": session_id,
"user_input": user_input,
"agent_response": agent_response
}
)
res.append({
"memory_id": memory_id,
"success": success,
"message": msg
})
res, msg = await memory_message_service.queue_save_to_memory_task(memory_ids, message_dict)
if all([r["success"] for r in res]):
return get_json_result(message="Successfully added to memories.")
if res:
return get_json_result(message=msg)
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add. Detail:" + msg)
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
@manager.route("/messages/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
@ -67,7 +81,7 @@ async def forget_message(memory_id: str, message_id: int):
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
@manager.route("/messages/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
@ -87,17 +101,16 @@ async def update_message(memory_id: str, message_id: int):
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
@manager.route("/messages/search", methods=["GET"]) # noqa: F821
@manager.route("/search", methods=["GET"]) # noqa: F821
@login_required
async def search_message():
args = request.args
print(args, flush=True)
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
if empty_fields:
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
memory_ids = args.getlist("memory_id")
if len(memory_ids) == 1 and ',' in memory_ids[0]:
memory_ids = memory_ids[0].split(',')
query = args.get("query")
similarity_threshold = float(args.get("similarity_threshold", 0.2))
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
@ -116,17 +129,15 @@ async def search_message():
"keywords_similarity_weight": keywords_similarity_weight,
"top_n": top_n
}
res = memory_message_service.query_message(filter_dict, params)
res = query_message(filter_dict, params)
return get_json_result(message=True, data=res)
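Several of these handlers accept repeated query parameters (memory_id, agent_id, tenant_id) and also tolerate a single comma-separated value, splitting it when getlist returns one element containing commas. A small sketch of that normalization:

```python
def normalize_multi_arg(values: list[str]) -> list[str]:
    """Accept either repeated query params or one comma-separated value."""
    if len(values) == 1 and "," in values[0]:
        return [v for v in values[0].split(",") if v]
    return values


# ?memory_id=a&memory_id=b  and  ?memory_id=a,b  normalize to the same list.
assert normalize_multi_arg(["a", "b"]) == ["a", "b"]
assert normalize_multi_arg(["a,b"]) == ["a", "b"]
```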
@manager.route("/messages", methods=["GET"]) # noqa: F821
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def get_messages():
args = request.args
memory_ids = args.getlist("memory_id")
if len(memory_ids) == 1 and ',' in memory_ids[0]:
memory_ids = memory_ids[0].split(',')
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
limit = int(args.get("limit", 10))
@ -144,7 +155,7 @@ async def get_messages():
return get_json_result(message=True, data=res)
@manager.route("/messages/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
@login_required
async def get_message_content(memory_id:str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)

View File

@ -18,7 +18,7 @@
from quart import Response
from api.apps import login_required
from api.utils.api_utils import get_json_result
from agent.plugin import GlobalPluginManager
from plugin import GlobalPluginManager
@manager.route('/llm_tools', methods=['GET']) # noqa: F821

View File

@ -51,7 +51,7 @@ def list_agents(tenant_id):
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
order_by = request.args.get("orderby", "update_time")
if str(request.args.get("desc","false")).lower() == "false":
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
@ -162,7 +162,6 @@ async def webhook(agent_id: str):
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
# 4. Check webhook configuration in DSL
webhook_cfg = {}
components = dsl.get("components", {})
for k, _ in components.items():
cpn_obj = components[k]["obj"]
@ -395,7 +394,7 @@ async def webhook(agent_id: str):
if not isinstance(cvs.dsl, str):
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
try:
canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
canvas = Canvas(dsl, cvs.user_id, agent_id)
except Exception as e:
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
resp.status_code = RetCode.BAD_REQUEST

View File

@ -51,9 +51,7 @@ async def create(tenant_id):
req["llm_id"] = llm.pop("model_name")
if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
model_type = llm.get("model_type")
model_type = model_type if model_type in ["chat", "image2text"] else "chat"
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
req["llm_setting"] = req.pop("llm")
e, tenant = TenantService.get_by_id(tenant_id)
@ -176,7 +174,7 @@ async def update(tenant_id, chat_id):
req["llm_id"] = llm.pop("model_name")
if req.get("llm_id") is not None:
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
model_type = llm.get("model_type")
model_type = llm.pop("model_type")
model_type = model_type if model_type in ["chat", "image2text"] else "chat"
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")

View File

@ -233,15 +233,6 @@ async def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
# Drop index for this dataset
try:
from rag.nlp import search
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")
if not KnowledgebaseService.delete_by_id(kb_id):
errors.append(f"Delete dataset error for {kb_id}")
continue
@ -490,7 +481,7 @@ def list_datasets(tenant_id):
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
@token_required
async def knowledge_graph(tenant_id, dataset_id):
def knowledge_graph(tenant_id, dataset_id):
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
@ -506,7 +497,7 @@ async def knowledge_graph(tenant_id, dataset_id):
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):
return get_result(data=obj)

View File

@ -18,7 +18,6 @@ import logging
from quart import jsonify
from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import meta_filter, convert_conditions
@ -122,7 +121,7 @@ async def retrieval(tenant_id):
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
metadata_condition = req.get("metadata_condition", {}) or {}
metas = DocMetadataService.get_meta_by_kbs([kb_id])
metas = DocumentService.get_meta_by_kbs([kb_id])
doc_ids = []
try:
@ -136,7 +135,7 @@ async def retrieval(tenant_id):
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
if not doc_ids and metadata_condition:
doc_ids = ["-999"]
ranks = await settings.retriever.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
kb.tenant_id,
@ -149,10 +148,9 @@ async def retrieval(tenant_id):
doc_ids=doc_ids,
rank_feature=label_question(question, [kb])
)
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], [tenant_id])
if use_kg:
ck = await settings.kg_retriever.retrieval(question,
ck = settings.kg_retriever.retrieval(question,
[tenant_id],
[kb_id],
embd_mdl,

View File

@ -29,7 +29,6 @@ from api.constants import FILE_NAME_LEN_LIMIT
from api.db import FileType
from api.db.db_models import File, Task
from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -256,8 +255,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
if "meta_fields" in req:
if not isinstance(req["meta_fields"], dict):
return get_error_data_result(message="meta_fields must be a dictionary")
if not DocMetadataService.update_document_metadata(document_id, req["meta_fields"]):
return get_error_data_result(message="Failed to update metadata")
DocumentService.update_meta_fields(document_id, req["meta_fields"])
if "name" in req and req["name"] != doc.name:
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
@ -570,7 +568,7 @@ def list_docs(dataset_id, tenant_id):
doc_ids_filter = None
if metadata_condition:
metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
if metadata_condition.get("conditions") and not doc_ids_filter:
return get_result(data={"total": 0, "docs": []})
@ -608,12 +606,12 @@ def list_docs(dataset_id, tenant_id):
@manager.route("/datasets/<dataset_id>/metadata/summary", methods=["GET"]) # noqa: F821
@token_required
async def metadata_summary(dataset_id, tenant_id):
def metadata_summary(dataset_id, tenant_id):
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
req = await get_request_json()
try:
summary = DocMetadataService.get_metadata_summary(dataset_id, req.get("doc_ids"))
summary = DocumentService.get_metadata_summary(dataset_id)
return get_result(data={"summary": summary})
except Exception as e:
return server_error_response(e)
@ -650,23 +648,23 @@ async def metadata_batch_update(dataset_id, tenant_id):
if not isinstance(d, dict) or not d.get("key"):
return get_error_data_result(message="Each delete requires key.")
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
target_doc_ids = set(kb_doc_ids)
if document_ids:
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
target_doc_ids = set(kb_doc_ids)
invalid_ids = set(document_ids) - set(kb_doc_ids)
if invalid_ids:
return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
target_doc_ids = set(document_ids)
if metadata_condition:
metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
target_doc_ids = target_doc_ids & filtered_ids
if metadata_condition.get("conditions") and not target_doc_ids:
return get_result(data={"updated": 0, "matched_docs": 0})
target_doc_ids = list(target_doc_ids)
updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
updated = DocumentService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
@ -937,7 +935,7 @@ async def stop_parsing(tenant_id, dataset_id):
@manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"]) # noqa: F821
@token_required
async def list_chunks(tenant_id, dataset_id, document_id):
def list_chunks(tenant_id, dataset_id, document_id):
"""
List chunks of a document.
---
@ -1083,7 +1081,7 @@ async def list_chunks(tenant_id, dataset_id, document_id):
_ = Chunk(**final_chunk)
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:
d = {
@ -1288,9 +1286,6 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
condition["id"] = unique_chunk_ids
else:
unique_chunk_ids = []
duplicate_messages = []
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
@ -1516,36 +1511,25 @@ async def retrieval_test(tenant_id):
page = int(req.get("page", 1))
size = int(req.get("page_size", 30))
question = req["question"]
# Trim whitespace and validate question
if isinstance(question, str):
question = question.strip()
# Return empty result if question is empty or whitespace-only
if not question:
return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
doc_ids = req.get("document_ids", [])
use_kg = req.get("use_kg", False)
toc_enhance = req.get("toc_enhance", False)
langs = req.get("cross_languages", [])
if not isinstance(doc_ids, list):
return get_error_data_result("`documents` should be a list")
if doc_ids:
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
for doc_id in doc_ids:
if doc_id not in doc_ids_list:
return get_error_data_result(f"The datasets don't own the document {doc_id}")
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
for doc_id in doc_ids:
if doc_id not in doc_ids_list:
return get_error_data_result(f"The datasets don't own the document {doc_id}")
if not doc_ids:
metadata_condition = req.get("metadata_condition")
if metadata_condition:
metas = DocMetadataService.get_meta_by_kbs(kb_ids)
doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
# If metadata_condition has conditions but no docs match, return empty result
if not doc_ids and metadata_condition.get("conditions"):
return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
if metadata_condition and not doc_ids:
doc_ids = ["-999"]
else:
# If doc_ids is None, all documents of the datasets are used
doc_ids = None
metadata_condition = req.get("metadata_condition", {}) or {}
metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
# If metadata_condition has conditions but no docs match, return empty result
if not doc_ids and metadata_condition.get("conditions"):
return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
if metadata_condition and not doc_ids:
doc_ids = ["-999"]
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
@ -1571,7 +1555,7 @@ async def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += await keyword_extraction(chat_mdl, question)
ranks = await settings.retriever.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
tenant_ids,
@ -1588,12 +1572,11 @@ async def retrieval_test(tenant_id):
)
if toc_enhance:
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
if cks:
ranks["chunks"] = cks
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
if use_kg:
ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import pathlib
import re
from quart import request, make_response
@ -23,7 +24,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
from common.misc_utils import get_uuid, thread_pool_exec
from common.misc_utils import get_uuid
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
@ -32,6 +33,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
from common.constants import RetCode
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
async def upload(tenant_id):
@ -638,7 +640,7 @@ async def get(tenant_id, file_id):
async def download_attachment(tenant_id, attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))

View File

@ -18,14 +18,9 @@ import copy
import re
import time
import os
import tempfile
import logging
import tiktoken
from quart import Response, jsonify, request
from common.token_utils import num_tokens_from_string
from agent.canvas import Canvas
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
@ -35,12 +30,12 @@ from api.db.services.conversation_service import ConversationService
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
from api.db.services.conversation_service import async_completion as rag_completion
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService,UserTenantService
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, token_required, validate_request
@ -65,7 +60,7 @@ async def create(tenant_id, chat_id):
"name": req.get("name", "New session"),
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
"user_id": req.get("user_id", ""),
"reference": [],
"reference": [{}],
}
if not conv.get("name"):
return get_error_data_result(message="`name` can not be empty.")
@ -93,7 +88,7 @@ async def create_agent_session(tenant_id, agent_id):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
session_id = get_uuid()
canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id)
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
canvas.reset()
cvs.dsl = json.loads(str(canvas))
@ -147,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
return get_error_data_result(message="metadata_condition must be an object.")
if metadata_condition and req.get("question"):
metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or [])
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
filtered_doc_ids = meta_filter(
metas,
convert_conditions(metadata_condition),
@ -192,7 +187,6 @@ async def chat_completion_openai_like(tenant_id, chat_id):
- If `stream` is True, the final answer and reference information will appear in the **last chunk** of the stream.
- If `stream` is False, the reference will be included in `choices[0].message.reference`.
- If `extra_body.reference_metadata.include` is True, each reference chunk may include `document_metadata` in both streaming and non-streaming responses.
Example usage:
@ -207,12 +201,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
Alternatively, you can use Python's `OpenAI` client:
NOTE: Streaming via `client.chat.completions.create(stream=True, ...)` does
not currently return `reference`. The only way to obtain `reference` is to
use non-stream mode with `with_raw_response`.
from openai import OpenAI
import json
model = "model"
client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/<chat_id>")
@ -220,20 +209,17 @@ async def chat_completion_openai_like(tenant_id, chat_id):
stream = True
reference = True
request_kwargs = dict(
model="model",
completion = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who are you?"},
{"role": "assistant", "content": "I am an AI assistant named..."},
{"role": "user", "content": "Can you tell me how to install neovim"},
],
stream=stream,
extra_body={
"reference": reference,
"reference_metadata": {
"include": True,
"fields": ["author", "year", "source"],
},
"metadata_condition": {
"logic": "and",
"conditions": [
@ -244,25 +230,19 @@ async def chat_completion_openai_like(tenant_id, chat_id):
}
]
}
},
}
)
if stream:
completion = client.chat.completions.create(stream=True, **request_kwargs)
for chunk in completion:
print(chunk)
for chunk in completion:
print(chunk)
if reference and chunk.choices[0].finish_reason == "stop":
print(f"Reference:\n{chunk.choices[0].delta.reference}")
print(f"Final content:\n{chunk.choices[0].delta.final_content}")
else:
resp = client.chat.completions.with_raw_response.create(
stream=False, **request_kwargs
)
print("status:", resp.http_response.status_code)
raw_text = resp.http_response.text
print("raw:", raw_text)
data = json.loads(raw_text)
print("assistant:", data["choices"][0]["message"].get("content"))
print("reference:", data["choices"][0]["message"].get("reference"))
print(completion.choices[0].message.content)
if reference:
print(completion.choices[0].message.reference)
"""
req = await get_request_json()
@ -271,13 +251,6 @@ async def chat_completion_openai_like(tenant_id, chat_id):
return get_error_data_result("extra_body must be an object.")
need_reference = bool(extra_body.get("reference", False))
reference_metadata = extra_body.get("reference_metadata") or {}
if reference_metadata and not isinstance(reference_metadata, dict):
return get_error_data_result("reference_metadata must be an object.")
include_reference_metadata = bool(reference_metadata.get("include", False))
metadata_fields = reference_metadata.get("fields")
if metadata_fields is not None and not isinstance(metadata_fields, list):
return get_error_data_result("reference_metadata.fields must be an array.")
messages = req.get("messages", [])
# To prevent empty [] input
@ -288,7 +261,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
prompt = messages[-1]["content"]
# Treat context tokens as reasoning tokens
context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages)
context_token_used = sum(len(message["content"]) for message in messages)
dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
if not dia:
@ -301,7 +274,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
doc_ids_str = None
if metadata_condition:
metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or [])
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
filtered_doc_ids = meta_filter(
metas,
convert_conditions(metadata_condition),
@ -331,12 +304,9 @@ async def chat_completion_openai_like(tenant_id, chat_id):
# The choices field on the last chunk will always be an empty array [].
async def streamed_response_generator(chat_id, dia, msg):
token_used = 0
answer_cache = ""
reasoning_cache = ""
last_ans = {}
full_content = ""
full_reasoning = ""
final_answer = None
final_reference = None
in_think = False
response = {
"id": f"chatcmpl-{chat_id}",
"choices": [
@ -366,30 +336,47 @@ async def chat_completion_openai_like(tenant_id, chat_id):
chat_kwargs["doc_ids"] = doc_ids_str
async for ans in async_chat(dia, msg, True, **chat_kwargs):
last_ans = ans
if ans.get("final"):
if ans.get("answer"):
full_content = ans["answer"]
final_answer = ans.get("answer") or full_content
final_reference = ans.get("reference", {})
continue
if ans.get("start_to_think"):
in_think = True
continue
if ans.get("end_to_think"):
in_think = False
continue
delta = ans.get("answer") or ""
if not delta:
continue
token_used += num_tokens_from_string(delta)
if in_think:
full_reasoning += delta
response["choices"][0]["delta"]["reasoning_content"] = delta
response["choices"][0]["delta"]["content"] = None
answer = ans["answer"]
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
if reasoning_match:
reasoning_part = reasoning_match.group(1)
content_part = answer[reasoning_match.end() :]
else:
reasoning_part = ""
content_part = answer
reasoning_incremental = ""
if reasoning_part:
if reasoning_part.startswith(reasoning_cache):
reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
else:
reasoning_incremental = reasoning_part
reasoning_cache = reasoning_part
content_incremental = ""
if content_part:
if content_part.startswith(answer_cache):
content_incremental = content_part.replace(answer_cache, "", 1)
else:
content_incremental = content_part
answer_cache = content_part
token_used += len(reasoning_incremental) + len(content_incremental)
if not any([reasoning_incremental, content_incremental]):
continue
if reasoning_incremental:
response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
else:
full_content += delta
response["choices"][0]["delta"]["content"] = delta
response["choices"][0]["delta"]["reasoning_content"] = None
if content_incremental:
response["choices"][0]["delta"]["content"] = content_incremental
else:
response["choices"][0]["delta"]["content"] = None
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
except Exception as e:
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@ -399,16 +386,10 @@ async def chat_completion_openai_like(tenant_id, chat_id):
response["choices"][0]["delta"]["content"] = None
response["choices"][0]["delta"]["reasoning_content"] = None
response["choices"][0]["finish_reason"] = "stop"
prompt_tokens = num_tokens_from_string(prompt)
response["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": token_used, "total_tokens": prompt_tokens + token_used}
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
if need_reference:
reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
response["choices"][0]["delta"]["reference"] = _build_reference_chunks(
reference_payload,
include_metadata=include_reference_metadata,
metadata_fields=metadata_fields,
)
response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
yield "data:[DONE]\n\n"
@ -435,12 +416,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
"created": int(time.time()),
"model": req.get("model", ""),
"usage": {
"prompt_tokens": num_tokens_from_string(prompt),
"completion_tokens": num_tokens_from_string(content),
"total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content),
"prompt_tokens": len(prompt),
"completion_tokens": len(content),
"total_tokens": len(prompt) + len(content),
"completion_tokens_details": {
"reasoning_tokens": context_token_used,
"accepted_prediction_tokens": num_tokens_from_string(content),
"accepted_prediction_tokens": len(content),
"rejected_prediction_tokens": 0, # 0 for simplicity
},
},
@ -457,11 +438,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
],
}
if need_reference:
response["choices"][0]["message"]["reference"] = _build_reference_chunks(
answer.get("reference", {}),
include_metadata=include_reference_metadata,
metadata_fields=metadata_fields,
)
response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", []))
return jsonify(response)
@ -471,6 +448,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
@token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await get_request_json()
tiktoken_encode = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
return get_error_data_result("You must provide at least one message.")
@ -478,7 +456,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
return get_error_data_result(f"You don't own the agent {agent_id}")
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
prompt_tokens = sum(num_tokens_from_string(m["content"]) for m in filtered_messages)
prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages)
if not filtered_messages:
return jsonify(
get_data_openai(
@ -486,7 +464,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
content="No valid messages found (user or assistant).",
finish_reason="stop",
model=req.get("model", ""),
completion_tokens=num_tokens_from_string("No valid messages found (user or assistant)."),
completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")),
prompt_tokens=prompt_tokens,
)
)
@ -965,7 +943,6 @@ async def chatbots_inputs(dialog_id):
"title": dialog.name,
"avatar": dialog.icon,
"prologue": dialog.prompt_config.get("prologue", ""),
"has_tavily_key": bool(dialog.prompt_config.get("tavily_api_key", "").strip()),
}
)
@ -1009,7 +986,7 @@ async def begin_inputs(agent_id):
if not e:
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id, canvas_id=cvs.id)
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
return get_result(
data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
"prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
@ -1081,13 +1058,11 @@ async def retrieval_test_embedded():
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
rerank_id = req.get("rerank_id", "")
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
async def _retrieval():
nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
_question = question
@ -1099,22 +1074,13 @@ async def retrieval_test_embedded():
meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
# Apply search_config settings if not explicitly provided in request
if not req.get("similarity_threshold"):
similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
if not req.get("vector_similarity_weight"):
vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
if not req.get("top_k"):
top = int(search_config.get("top_k", top))
if not req.get("rerank_id"):
rerank_id = search_config.get("rerank_id", "")
else:
meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
metas = DocumentService.get_meta_by_kbs(kb_ids)
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids)
tenants = UserTenantService.query(user_id=tenant_id)
@ -1137,20 +1103,20 @@ async def retrieval_test_embedded():
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
rerank_mdl = None
if rerank_id:
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=rerank_id)
if req.get("rerank_id"):
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
_question += await keyword_extraction(chat_mdl, _question)
labels = label_question(_question, [kb])
ranks = await settings.retriever.retrieval(
ranks = settings.retriever.retrieval(
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
if use_kg:
ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
@ -1267,135 +1233,3 @@ async def mindmap():
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
@token_required
async def sequence2txt(tenant_id):
req = await request.form
stream_mode = req.get("stream", "false").lower() == "true"
files = await request.files
if "file" not in files:
return get_error_data_result(message="Missing 'file' in multipart form-data")
uploaded = files["file"]
ALLOWED_EXTS = {
".wav", ".mp3", ".m4a", ".aac",
".flac", ".ogg", ".webm",
".opus", ".wma"
}
filename = uploaded.filename or ""
suffix = os.path.splitext(filename)[-1].lower()
if suffix not in ALLOWED_EXTS:
return get_error_data_result(message=
f"Unsupported audio format: {suffix}. "
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
)
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
os.close(fd)
await uploaded.save(temp_audio_path)
tenants = TenantService.get_info_by(tenant_id)
if not tenants:
return get_error_data_result(message="Tenant not found!")
asr_id = tenants[0]["asr_id"]
if not asr_id:
return get_error_data_result(message="No default ASR model is set")
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
if not stream_mode:
text = asr_mdl.transcription(temp_audio_path)
try:
os.remove(temp_audio_path)
except Exception as e:
logging.error(f"Failed to remove temp audio file: {str(e)}")
return get_json_result(data={"text": text})
async def event_stream():
try:
for evt in asr_mdl.stream_transcription(temp_audio_path):
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
except Exception as e:
err = {"event": "error", "text": str(e)}
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
finally:
try:
os.remove(temp_audio_path)
except Exception as e:
logging.error(f"Failed to remove temp audio file: {str(e)}")
return Response(event_stream(), content_type="text/event-stream")
@manager.route("/tts", methods=["POST"]) # noqa: F821
@token_required
async def tts(tenant_id):
req = await get_request_json()
text = req["text"]
tenants = TenantService.get_info_by(tenant_id)
if not tenants:
return get_error_data_result(message="Tenant not found!")
tts_id = tenants[0]["tts_id"]
if not tts_id:
return get_error_data_result(message="No default TTS model is set")
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
def stream_audio():
try:
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
for chunk in tts_mdl.tts(txt):
yield chunk
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None):
chunks = chunks_format(reference)
if not include_metadata:
return chunks
doc_ids_by_kb = {}
for chunk in chunks:
kb_id = chunk.get("dataset_id")
doc_id = chunk.get("document_id")
if not kb_id or not doc_id:
continue
doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id)
if not doc_ids_by_kb:
return chunks
meta_by_doc = {}
for kb_id, doc_ids in doc_ids_by_kb.items():
meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id)
if meta_map:
meta_by_doc.update(meta_map)
if metadata_fields is not None:
metadata_fields = {f for f in metadata_fields if isinstance(f, str)}
if not metadata_fields:
return chunks
for chunk in chunks:
doc_id = chunk.get("document_id")
if not doc_id:
continue
meta = meta_by_doc.get(doc_id)
if not meta:
continue
if metadata_fields is not None:
meta = {k: v for k, v in meta.items() if k in metadata_fields}
if meta:
chunk["document_metadata"] = meta
return chunks
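A hypothetical call illustrating the helper's effect, with placeholder values: when `include_metadata` is true, chunks whose documents have metadata gain a `document_metadata` entry, optionally restricted to the requested `fields`:

chunks = _build_reference_chunks(
    answer.get("reference", {}),          # reference payload from the chat answer
    include_metadata=True,
    metadata_fields=["author", "year"],   # keep only these keys, if present
)
# A chunk may then look like (values are placeholders):
# {"dataset_id": "kb-1", "document_id": "doc-1", ..., "document_metadata": {"author": "...", "year": 2024}}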

View File

@ -35,7 +35,7 @@ from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN
from quart import jsonify
from api.utils.health_utils import run_health_checks, get_oceanbase_status
from api.utils.health_utils import run_health_checks
from common import settings
@ -177,47 +177,11 @@ def healthz():
return jsonify(result), (200 if all_ok else 500)
@manager.route("/ping", methods=["GET"]) # noqa: F821
async def ping():
@manager.route("/ping", methods=["GET"]) # noqa: F821
def ping():
return "pong", 200
@manager.route("/oceanbase/status", methods=["GET"]) # noqa: F821
@login_required
def oceanbase_status():
"""
Get OceanBase health status and performance metrics.
---
tags:
- System
security:
- ApiKeyAuth: []
responses:
200:
description: OceanBase status retrieved successfully.
schema:
type: object
properties:
status:
type: string
description: Status (alive/timeout).
message:
type: object
description: Detailed status information including health and performance metrics.
"""
try:
status_info = get_oceanbase_status()
return get_json_result(data=status_info)
except Exception as e:
return get_json_result(
data={
"status": "error",
"message": f"Failed to get OceanBase status: {str(e)}"
},
code=500
)
@manager.route("/new_token", methods=["POST"]) # noqa: F821
@login_required
def new_token():
@ -249,7 +213,7 @@ def new_token():
if not tenants:
return get_data_error_result(message="Tenant not found!")
tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
obj = {
"tenant_id": tenant_id,
"token": generate_confirmation_token(),
@ -304,12 +268,13 @@ def token_list():
if not tenants:
return get_data_error_result(message="Tenant not found!")
tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
objs = APITokenService.query(tenant_id=tenant_id)
objs = [o.to_dict() for o in objs]
for o in objs:
if not o["beta"]:
o["beta"] = generate_confirmation_token().replace("ragflow-", "")[:32]
o["beta"] = generate_confirmation_token().replace(
"ragflow-", "")[:32]
APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
return get_json_result(data=objs)
except Exception as e:
@ -342,19 +307,13 @@ def rm(token):
type: boolean
description: Deletion status.
"""
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
return get_data_error_result(message="Tenant not found!")
tenant_id = tenants[0].tenant_id
APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == token])
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
APITokenService.filter_delete(
[APIToken.tenant_id == current_user.id, APIToken.token == token]
)
return get_json_result(data=True)
@manager.route("/config", methods=["GET"]) # noqa: F821
@manager.route('/config', methods=['GET']) # noqa: F821
def get_config():
"""
Get system configuration.
@ -371,4 +330,6 @@ def get_config():
type: integer 0 means disabled, 1 means enabled
description: Whether user registration is enabled
"""
return get_json_result(data={"registerEnabled": settings.REGISTER_ENABLED})
return get_json_result(data={
"registerEnabled": settings.REGISTER_ENABLED
})

View File

@ -98,6 +98,8 @@ async def login():
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
email = json_body.get("email", "")
if email == "admin@ragflow.io":
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="The default admin account cannot be used to log in to normal services!")
users = UserService.query(email=email)
if not users:

View File

@ -48,7 +48,6 @@ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_ac
class TextFieldType(Enum):
MYSQL = "LONGTEXT"
OCEANBASE = "LONGTEXT"
POSTGRES = "TEXT"
@ -282,11 +281,7 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
except Exception as e:
logging.error(f"Failed to reconnect: {e}")
time.sleep(0.1)
try:
self.connect()
except Exception as e2:
logging.error(f"Failed to reconnect on second attempt: {e2}")
raise
self.connect()
def begin(self):
for attempt in range(self.max_retries + 1):
@ -357,11 +352,7 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
except Exception as e:
logging.error(f"Failed to reconnect to PostgreSQL: {e}")
time.sleep(0.1)
try:
self.connect()
except Exception as e2:
logging.error(f"Failed to reconnect to PostgreSQL on second attempt: {e2}")
raise
self.connect()
def begin(self):
for attempt in range(self.max_retries + 1):
@ -384,95 +375,13 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
return None
class RetryingPooledOceanBaseDatabase(PooledMySQLDatabase):
"""Pooled OceanBase database with retry mechanism.
OceanBase is compatible with MySQL protocol, so we inherit from PooledMySQLDatabase.
This class provides connection pooling and automatic retry for connection issues.
"""
def __init__(self, *args, **kwargs):
self.max_retries = kwargs.pop("max_retries", 5)
self.retry_delay = kwargs.pop("retry_delay", 1)
super().__init__(*args, **kwargs)
def execute_sql(self, sql, params=None, commit=True):
for attempt in range(self.max_retries + 1):
try:
return super().execute_sql(sql, params, commit)
except (OperationalError, InterfaceError) as e:
# OceanBase/MySQL specific error codes
# 2013: Lost connection to MySQL server during query
# 2006: MySQL server has gone away
error_codes = [2013, 2006]
error_messages = ['', 'Lost connection', 'gone away']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
any(msg in str(e).lower() for msg in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"OceanBase connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
)
self._handle_connection_loss()
time.sleep(self.retry_delay * (2 ** attempt))
else:
logging.error(f"OceanBase execution failure: {e}")
raise
return None
def _handle_connection_loss(self):
try:
self.close()
except Exception:
pass
try:
self.connect()
except Exception as e:
logging.error(f"Failed to reconnect to OceanBase: {e}")
time.sleep(0.1)
try:
self.connect()
except Exception as e2:
logging.error(f"Failed to reconnect to OceanBase on second attempt: {e2}")
raise
def begin(self):
for attempt in range(self.max_retries + 1):
try:
return super().begin()
except (OperationalError, InterfaceError) as e:
error_codes = [2013, 2006]
error_messages = ['', 'Lost connection']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
(str(e) in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
)
self._handle_connection_loss()
time.sleep(self.retry_delay * (2 ** attempt))
else:
raise
return None
class PooledDatabase(Enum):
MYSQL = RetryingPooledMySQLDatabase
OCEANBASE = RetryingPooledOceanBaseDatabase
POSTGRES = RetryingPooledPostgresqlDatabase
class DatabaseMigrator(Enum):
MYSQL = MySQLMigrator
OCEANBASE = MySQLMigrator
POSTGRES = PostgresqlMigrator
@ -631,7 +540,6 @@ class MysqlDatabaseLock:
class DatabaseLock(Enum):
MYSQL = MysqlDatabaseLock
OCEANBASE = MysqlDatabaseLock
POSTGRES = PostgresDatabaseLock
@ -879,6 +787,7 @@ class Document(DataBaseModel):
progress_msg = TextField(null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True, index=True)
process_duration = FloatField(default=0)
meta_fields = JSONField(null=True, default={})
suffix = CharField(max_length=32, null=False, help_text="The real file extension suffix", index=True)
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True)
@ -991,10 +900,8 @@ class APIToken(DataBaseModel):
class API4Conversation(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
name = CharField(max_length=255, null=True, help_text="conversation name", index=False)
dialog_id = CharField(max_length=32, null=False, index=True)
user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
exp_user_id = CharField(max_length=255, null=True, help_text="exp_user_id", index=True)
message = JSONField(null=True)
reference = JSONField(null=True, default=[])
tokens = IntegerField(default=0)
@ -1282,7 +1189,7 @@ class Memory(DataBaseModel):
permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me")
description = TextField(null=True, help_text="description")
memory_size = IntegerField(default=5242880, null=False, index=False)
forgetting_policy = CharField(max_length=32, null=False, default="FIFO", index=False, help_text="LRU|FIFO")
forgetting_policy = CharField(max_length=32, null=False, default="fifo", index=False, help_text="lru|fifo")
temperature = FloatField(default=0.5, index=False)
system_prompt = TextField(null=True, help_text="system prompt", index=False)
user_prompt = TextField(null=True, help_text="user prompt", index=False)
@ -1290,96 +1197,224 @@ class Memory(DataBaseModel):
class Meta:
db_table = "memory"
class SystemSettings(DataBaseModel):
name = CharField(max_length=128, primary_key=True)
source = CharField(max_length=32, null=False, index=False)
data_type = CharField(max_length=32, null=False, index=False)
value = TextField(null=False, help_text="Configuration value (JSON, string, etc.)")
class Meta:
db_table = "system_settings"
def alter_db_add_column(migrator, table_name, column_name, column_type):
try:
migrate(migrator.add_column(table_name, column_name, column_type))
except OperationalError as ex:
error_codes = [1060]
error_messages = ['Duplicate column name']
should_skip_error = (
(hasattr(ex, 'args') and ex.args and ex.args[0] in error_codes) or
(str(ex) in error_messages)
)
if not should_skip_error:
logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, operation error: {ex}")
except Exception as ex:
logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, error: {ex}")
pass
def alter_db_column_type(migrator, table_name, column_name, new_column_type):
try:
migrate(migrator.alter_column_type(table_name, column_name, new_column_type))
except Exception as ex:
logging.critical(f"Failed to alter {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name} type, error: {ex}")
pass
def alter_db_rename_column(migrator, table_name, old_column_name, new_column_name):
try:
migrate(migrator.rename_column(table_name, old_column_name, new_column_name))
except Exception:
# a failed rename will lead to a weird error.
# logging.critical(f"Failed to rename {settings.DATABASE_TYPE.upper()}.{table_name} column {old_column_name} to {new_column_name}, error: {ex}")
pass
def migrate_db():
logging.disable(logging.ERROR)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
alter_db_add_column(migrator, "file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True))
alter_db_add_column(migrator, "tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
alter_db_add_column(migrator, "dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
alter_db_column_type(migrator, "dialog", "top_k", IntegerField(default=1024))
alter_db_add_column(migrator, "tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True))
alter_db_add_column(migrator, "api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
alter_db_add_column(migrator, "tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
alter_db_add_column(migrator, "api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
alter_db_add_column(migrator, "task", "retry_count", IntegerField(default=0))
alter_db_column_type(migrator, "api_token", "dialog_id", CharField(max_length=32, null=True, index=True))
alter_db_add_column(migrator, "tenant_llm", "max_tokens", IntegerField(default=8192, index=True))
alter_db_add_column(migrator, "api_4_conversation", "dsl", JSONField(null=True, default={}))
alter_db_add_column(migrator, "knowledgebase", "pagerank", IntegerField(default=0, index=False))
alter_db_add_column(migrator, "api_token", "beta", CharField(max_length=255, null=True, index=True))
alter_db_add_column(migrator, "task", "digest", TextField(null=True, help_text="task digest", default=""))
alter_db_add_column(migrator, "task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
alter_db_add_column(migrator, "conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True))
alter_db_add_column(migrator, "task", "task_type", CharField(max_length=32, null=False, default=""))
alter_db_add_column(migrator, "task", "priority", IntegerField(default=0))
alter_db_add_column(migrator, "user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True))
alter_db_add_column(migrator, "llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))
alter_db_add_column(migrator, "mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict))
alter_db_rename_column(migrator, "task", "process_duation", "process_duration")
alter_db_rename_column(migrator, "document", "process_duation", "process_duration")
alter_db_add_column(migrator, "document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True))
alter_db_add_column(migrator, "api_4_conversation", "errors", TextField(null=True, help_text="errors"))
alter_db_add_column(migrator, "dialog", "meta_data_filter", JSONField(null=True, default={}))
alter_db_column_type(migrator, "canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title"))
alter_db_column_type(migrator, "canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description"))
alter_db_add_column(migrator, "user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
alter_db_add_column(migrator, "canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
alter_db_add_column(migrator, "knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
alter_db_add_column(migrator, "document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
alter_db_add_column(migrator, "knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True))
alter_db_add_column(migrator, "knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True))
alter_db_add_column(migrator, "knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True))
alter_db_add_column(migrator, "knowledgebase", "raptor_task_finish_at", CharField(null=True))
alter_db_add_column(migrator, "knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True))
alter_db_add_column(migrator, "knowledgebase", "mindmap_task_finish_at", CharField(null=True))
alter_db_column_type(migrator, "tenant_llm", "api_key", TextField(null=True, help_text="API KEY"))
alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True))
alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False))
alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False))
alter_db_add_column(migrator, "api_4_conversation", "name", CharField(max_length=255, null=True, help_text="conversation name", index=False))
alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True))
# Migrate system_settings.value from CharField to TextField for longer sandbox configs
alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
try:
migrate(migrator.add_column("file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")))
except Exception:
pass
try:
migrate(migrator.add_column("dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID")))
except Exception:
pass
try:
migrate(migrator.add_column("dialog", "top_k", IntegerField(default=1024)))
except Exception:
pass
try:
migrate(migrator.alter_column_type("tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("task", "retry_count", IntegerField(default=0)))
except Exception:
pass
try:
migrate(migrator.alter_column_type("api_token", "dialog_id", CharField(max_length=32, null=True, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={})))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)))
except Exception:
pass
try:
migrate(migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")))
except Exception:
pass
try:
migrate(migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")))
except Exception:
pass
try:
migrate(migrator.add_column("conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("document", "meta_fields", JSONField(null=True, default={})))
except Exception:
pass
try:
migrate(migrator.add_column("task", "task_type", CharField(max_length=32, null=False, default="")))
except Exception:
pass
try:
migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
except Exception:
pass
try:
migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False)))
except Exception:
pass
try:
migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict)))
except Exception:
pass
try:
migrate(migrator.rename_column("task", "process_duation", "process_duration"))
except Exception:
pass
try:
migrate(migrator.rename_column("document", "process_duation", "process_duration"))
except Exception:
pass
try:
migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
except Exception:
pass
try:
migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
except Exception:
pass
try:
migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
except Exception:
pass
try:
migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
except Exception:
pass
try:
migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
except Exception:
pass
try:
migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
except Exception:
pass
try:
migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)))
except Exception:
pass
try:
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
except Exception:
pass
# RAG Evaluation tables
try:
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
except Exception:
pass
logging.disable(logging.NOTSET)
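A minimal, self-contained sketch (peewee + SQLite, not the project's actual migration runner) of the idempotent pattern the block above relies on: every schema change is wrapped in try/except, so re-running the upgrade against an already-migrated database degrades to a no-op.

from peewee import SqliteDatabase, Model, CharField
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase(":memory:")

class Document(Model):
    name = CharField(default="")

    class Meta:
        database = db

db.connect()
db.create_tables([Document])
migrator = SqliteMigrator(db)

def add_column_if_missing(table: str, column: str, field):
    # Duplicate-column (and any other) errors are swallowed, mirroring the try/except
    # blocks above, which keeps the migration safely re-runnable.
    try:
        migrate(migrator.add_column(table, column, field))
    except Exception:
        pass

add_column_if_missing("document", "pipeline_id", CharField(null=True))
add_column_if_missing("document", "pipeline_id", CharField(null=True))  # second call is a no-op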

View File

@ -30,8 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api.db.services.system_settings_service import SystemSettingsService
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache, fix_missing_tokenized_memory
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import settings
@ -159,15 +158,13 @@ def add_graph_templates():
CanvasTemplateService.save(**cnvs)
except Exception:
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
except Exception as e:
logging.exception(f"Add agent templates error: {e}")
except Exception:
logging.exception("Add agent templates error: ")
def init_web_data():
start_time = time.time()
init_table()
init_llm_factory()
# if not UserService.get_all().count():
# init_superuser()
@ -175,34 +172,8 @@ def init_web_data():
add_graph_templates()
init_message_id_sequence()
init_memory_size_cache()
fix_missing_tokenized_memory()
logging.info("init web data success:{}".format(time.time() - start_time))
def init_table():
# init system_settings
with open(os.path.join(get_project_base_directory(), "conf", "system_settings.json"), "r") as f:
records_from_file = json.load(f)["system_settings"]
record_index = {}
records_from_db = SystemSettingsService.get_all()
for index, record in enumerate(records_from_db):
record_index[record.name] = index
to_save = []
for record in records_from_file:
setting_name = record["name"]
if setting_name not in record_index:
to_save.append(record)
len_to_save = len(to_save)
if len_to_save > 0:
# not initialized
try:
SystemSettingsService.insert_many(to_save, len_to_save)
except Exception as e:
logging.exception("System settings init error: {}".format(e))
raise e
if __name__ == '__main__':
init_web_db()

View File

@ -16,14 +16,9 @@
import logging
from typing import List
from common import settings
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
from common.constants import MemoryType, LLMType
from common.doc_store.doc_store_base import FusionExpr
from common.misc_utils import get_uuid
from api.db.db_utils import bulk_insert_into_db
from api.db.db_models import Task
from api.db.services.task_service import TaskService
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
@ -87,60 +82,36 @@ async def save_to_memory(memory_id: str, message_dict: dict):
"forget_at": None,
"status": True
} for content in extracted_content]]
return await embed_and_save(memory, message_list)
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
vector_dimension = len(vector_list[0])
if not MessageService.has_index(tenant_id, memory_id):
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
if not created:
return False, "Failed to create message index."
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
if new_msg_size + current_memory_size > memory.memory_size:
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
if memory.forgetting_policy == "fifo":
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
decrease_memory_size_cache(memory_id, tenant_id, delete_size)
else:
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
if fail_cases:
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
msg = f"Memory '{memory_id}' not found."
if task_id:
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
return False, msg
if memory.memory_type == MemoryType.RAW.value:
msg = f"Memory '{memory_id}' don't need to extract."
if task_id:
TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
return True, msg
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", ""),
task_id=task_id
)
message_list = [{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
"message_type": content["message_type"],
"source_id": source_message_id,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": content["content"],
"valid_at": content["valid_at"],
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
"forget_at": None,
"status": True
} for content in extracted_content]
if not message_list:
msg = "No memory extracted from raw message."
if task_id:
TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
return True, msg
if task_id:
TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp())+ " " + f"Extracted {len(message_list)} messages from raw dialogue."})
return await embed_and_save(memory, message_list, task_id)
increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
return True, "Message saved successfully."
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
if not llm_type:
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
@ -155,12 +126,8 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
llm = LLMBundle(tenant_id, llm_type, llm_id)
if task_id:
TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared prompts and LLM."})
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
res_json = get_json_result_from_llm_response(res)
if task_id:
TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Get extracted result from LLM."})
return [{
"content": extracted_content["content"],
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
@ -169,51 +136,6 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
if task_id:
TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Prepared embedding model."})
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
if task_id:
TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Embedded extracted content."})
vector_dimension = len(vector_list[0])
if not MessageService.has_index(memory.tenant_id, memory.id):
created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
if not created:
error_msg = "Failed to create message index."
if task_id:
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
return False, error_msg
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
if new_msg_size + current_memory_size > memory.memory_size:
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
if memory.forgetting_policy == "FIFO":
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id,
size_to_delete)
MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
decrease_memory_size_cache(memory.id, delete_size)
else:
error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
if task_id:
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
return False, error_msg
fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
if fail_cases:
error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
if task_id:
TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + error_msg})
return False, error_msg
if task_id:
TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp())+ " " + "Saved messages to storage."})
increase_memory_size_cache(memory.id, new_msg_size)
return True, "Message saved successfully."
def query_message(filter_dict: dict, params: dict):
"""
:param filter_dict: {
@ -241,9 +163,9 @@ def query_message(filter_dict: dict, params: dict):
memory = memory_list[0]
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
match_text, _ = MsgTextQuery().question(question, min_match=params["similarity_threshold"])
match_text, _ = MsgTextQuery().question(question, min_match=0.3)
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(1 - keywords_similarity_weight), str(keywords_similarity_weight)])})
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})
return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
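A minimal illustration (standalone, not the doc-store implementation) of what the weighted_sum fusion expression above encodes: the ranking score is a convex combination of the full-text and dense-vector match scores, and the comma-separated weight string pairs positionally with the match expressions passed to search_message.

def weighted_sum(text_score: float, vector_score: float, w_text: float, w_vector: float) -> float:
    # w_text applies to the full-text match and w_vector to the dense match when the
    # match expressions are supplied in that order; both weights are derived from
    # keywords_similarity_weight in the service code above.
    return w_text * text_score + w_vector * vector_score

assert abs(weighted_sum(0.8, 0.4, 0.7, 0.3) - 0.68) < 1e-9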
@ -269,8 +191,8 @@ def init_message_id_sequence():
def get_memory_size_cache(memory_id: str, uid: str):
redis_key = f"memory_{memory_id}"
if REDIS_CONN.exist(redis_key):
return int(REDIS_CONN.get(redis_key))
if REDIS_CONN.exists(redis_key):
return REDIS_CONN.get(redis_key)
else:
memory_size_map = MessageService.calculate_memory_size(
[memory_id],
@ -286,14 +208,14 @@ def set_memory_size_cache(memory_id: str, size: int):
return REDIS_CONN.set(redis_key, size)
def increase_memory_size_cache(memory_id: str, size: int):
redis_key = f"memory_{memory_id}"
return REDIS_CONN.incrby(redis_key, size)
def increase_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, current_value + size)
def decrease_memory_size_cache(memory_id: str, size: int):
redis_key = f"memory_{memory_id}"
return REDIS_CONN.decrby(redis_key, size)
def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, max(current_value - size, 0))
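A self-contained sketch of the read-through size cache shown above, with a plain dict standing in for Redis (the memory_{id} key layout mirrors the code; the recomputation stub is an assumption for illustration).

_cache = {}

def load_size_from_store(memory_id: str) -> int:
    # Stand-in for MessageService.calculate_memory_size(...) on a cache miss.
    return 0

def get_memory_size(memory_id: str) -> int:
    key = f"memory_{memory_id}"
    if key not in _cache:                       # miss: recompute and populate
        _cache[key] = load_size_from_store(memory_id)
    return _cache[key]

def increase_memory_size(memory_id: str, size: int) -> int:
    _cache[f"memory_{memory_id}"] = get_memory_size(memory_id) + size          # get-then-set
    return _cache[f"memory_{memory_id}"]

def decrease_memory_size(memory_id: str, size: int) -> int:
    _cache[f"memory_{memory_id}"] = max(get_memory_size(memory_id) - size, 0)  # clamp at zero
    return _cache[f"memory_{memory_id}"]

increase_memory_size("m1", 120)
decrease_memory_size("m1", 200)
assert get_memory_size("m1") == 0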
def init_memory_size_cache():
@ -301,138 +223,11 @@ def init_memory_size_cache():
if not memory_list:
logging.info("No memory found, no need to init memory size.")
else:
for m in memory_list:
get_memory_size_cache(m.id, m.tenant_id)
memory_size_map = MessageService.calculate_memory_size(
memory_ids=[m.id for m in memory_list],
uid_list=[m.tenant_id for m in memory_list],
)
for memory in memory_list:
memory_size = memory_size_map.get(memory.id, 0)
set_memory_size_cache(memory.id, memory_size)
logging.info("Memory size cache init done.")
def fix_missing_tokenized_memory():
if settings.DOC_ENGINE != "elasticsearch":
logging.info("Not using elasticsearch as doc engine, no need to fix missing tokenized memory.")
return
memory_list = MemoryService.get_all_memory()
if not memory_list:
logging.info("No memory found, no need to fix missing tokenized memory.")
else:
for m in memory_list:
message_list = MessageService.get_missing_field_messages(m.id, m.tenant_id, "tokenized_content_ltks")
for msg in message_list:
# update content to refresh tokenized field
MessageService.update_message({"message_id": msg["message_id"], "memory_id": m.id}, {"content": msg["content"]}, m.tenant_id, m.id)
if message_list:
logging.info(f"Fixed {len(message_list)} messages missing tokenized field in memory: {m.name}.")
logging.info("Fix missing tokenized memory done.")
def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict):
"""
:param memory_ids:
:param message_dict: {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
"""
def new_task(_memory_id: str, _source_id: int):
return {
"id": get_uuid(),
"doc_id": _memory_id,
"task_type": "memory",
"progress": 0.0,
"digest": str(_source_id)
}
not_found_memory = []
failed_memory = []
for memory_id in memory_ids:
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
not_found_memory.append(memory_id)
continue
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
raw_message = {
"message_id": raw_message_id,
"message_type": MemoryType.RAW.name.lower(),
"source_id": 0,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
"valid_at": timestamp_to_date(current_timestamp()),
"invalid_at": None,
"forget_at": None,
"status": True
}
res, msg = await embed_and_save(memory, [raw_message])
if not res:
failed_memory.append({"memory_id": memory_id, "fail_msg": msg})
continue
task = new_task(memory_id, raw_message_id)
bulk_insert_into_db(Task, [task], replace_on_conflict=True)
task_message = {
"id": task["id"],
"task_id": task["id"],
"task_type": task["task_type"],
"memory_id": memory_id,
"source_id": raw_message_id,
"message_dict": message_dict
}
if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message):
failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."})
error_msg = ""
if not_found_memory:
error_msg = f"Memory {not_found_memory} not found."
if failed_memory:
error_msg += "".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory])
if error_msg:
return False, error_msg
return True, "All add to task."
async def handle_save_to_memory_task(task_param: dict):
"""
:param task_param: {
"id": task_id
"memory_id": id
"source_id": id
"message_dict": {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
}
"""
_, task = TaskService.get_by_id(task_param["id"])
if not task:
return False, f"Task {task_param['id']} is not found."
if task.progress == -1:
return False, f"Task {task_param['id']} is already failed."
now_time = current_timestamp()
TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)})
memory_id = task_param["memory_id"]
source_id = task_param["source_id"]
message_dict = task_param["message_dict"]
success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id)
if success:
TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
return True, msg
logging.error(msg)
TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp())+ " " + msg})
return False, msg

View File

@ -23,7 +23,6 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService
from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
@ -108,11 +107,6 @@ def create_new_user(user_info: dict) -> dict:
except Exception as create_error:
logging.exception(create_error)
# rollback
try:
metadata_index_name = DocMetadataService._get_doc_meta_index_name(user_id)
settings.docStoreConn.delete_idx(metadata_index_name, "")
except Exception as e:
logging.exception(e)
try:
TenantService.delete_by_id(user_id)
except Exception as e:
@ -171,12 +165,6 @@ def delete_user_data(user_id: str) -> dict:
# step1.1.2 delete file and document info in db
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
if doc_ids:
for doc in doc_ids:
try:
DocMetadataService.delete_document_metadata(doc["id"], skip_empty_check=True)
except Exception as e:
logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}")
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
done_msg += f"- Deleted {doc_delete_res} document records.\n"
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
@ -214,13 +202,6 @@ def delete_user_data(user_id: str) -> dict:
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
try:
metadata_index_name = DocMetadataService._get_doc_meta_index_name(tenant_id)
settings.docStoreConn.delete_idx(metadata_index_name, "")
done_msg += f"- Deleted metadata table {metadata_index_name}.\n"
except Exception as e:
logging.warning(f"Failed to delete metadata table for tenant {tenant_id}: {e}")
done_msg += "- Warning: Failed to delete metadata table (continuing).\n"
# step1.3 delete memory and messages
user_memory = MemoryService.get_by_tenant_id(tenant_id)
if user_memory:
@ -288,11 +269,6 @@ def delete_user_data(user_id: str) -> dict:
# step2.1.5 delete document record
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
done_msg += f"- Deleted {doc_delete_res} documents.\n"
for doc in created_documents:
try:
DocMetadataService.delete_document_metadata(doc['id'])
except Exception as e:
logging.warning(f"Failed to delete metadata for document {doc['id']}: {e}")
# step2.1.6 update dataset doc&chunk&token cnt
for kb_id, doc_num in kb_doc_info.items():
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
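A short sketch (hypothetical delete_one callback) of the best-effort cleanup style delete_user_data uses above: each sub-deletion runs in its own try/except, failures are logged as warnings, and a human-readable summary is accumulated instead of aborting the whole teardown.

import logging

def delete_docs_best_effort(doc_ids, delete_one):
    done_msg, deleted = "", 0
    for doc_id in doc_ids:
        try:
            delete_one(doc_id)
            deleted += 1
        except Exception as e:
            logging.warning(f"Failed to delete document {doc_id}: {e}")
            done_msg += f"- Warning: failed to delete document {doc_id} (continuing).\n"
    done_msg += f"- Deleted {deleted} document records.\n"
    return done_msg

print(delete_docs_best_effort(["doc_a", "doc_b"], lambda _id: None))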

View File

@ -48,8 +48,8 @@ class API4ConversationService(CommonService):
@DB.connection_context()
def get_list(cls, dialog_id, tenant_id,
page_number, items_per_page,
orderby, desc, id=None, user_id=None, include_dsl=True, keywords="",
from_date=None, to_date=None, exp_user_id=None
orderby, desc, id, user_id=None, include_dsl=True, keywords="",
from_date=None, to_date=None
):
if include_dsl:
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
@ -66,8 +66,6 @@ class API4ConversationService(CommonService):
sessions = sessions.where(cls.model.create_date >= from_date)
if to_date:
sessions = sessions.where(cls.model.create_date <= to_date)
if exp_user_id:
sessions = sessions.where(cls.model.exp_user_id == exp_user_id)
if desc:
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
@ -77,17 +75,6 @@ class API4ConversationService(CommonService):
return count, list(sessions.dicts())
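A self-contained sketch (toy Session model, not the real API4Conversation schema) of the optional filter chaining get_list performs above: each argument, when provided, narrows the same select().

from peewee import SqliteDatabase, Model, CharField

db = SqliteDatabase(":memory:")

class Session(Model):
    dialog_id = CharField()
    user_id = CharField(null=True)
    create_date = CharField()   # ISO date string keeps the toy example simple

    class Meta:
        database = db

db.connect()
db.create_tables([Session])
Session.create(dialog_id="d1", user_id="u1", create_date="2025-12-01")
Session.create(dialog_id="d1", user_id="u2", create_date="2025-12-20")

def list_sessions(dialog_id, user_id=None, from_date=None, to_date=None):
    q = Session.select().where(Session.dialog_id == dialog_id)
    if user_id:                                   # applied only when provided
        q = q.where(Session.user_id == user_id)
    if from_date:
        q = q.where(Session.create_date >= from_date)
    if to_date:
        q = q.where(Session.create_date <= to_date)
    return list(q.dicts())

assert len(list_sessions("d1", from_date="2025-12-10")) == 1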
@classmethod
@DB.connection_context()
def get_names(cls, dialog_id, exp_user_id):
fields = [cls.model.id, cls.model.name,]
sessions = cls.model.select(*fields).where(
cls.model.dialog_id == dialog_id,
cls.model.exp_user_id == exp_user_id
).order_by(cls.model.getter_by("create_date").desc())
return list(sessions.dicts())
@classmethod
@DB.connection_context()
def append_message(cls, id, conversation):

View File

@ -146,6 +146,7 @@ class UserCanvasService(CommonService):
cls.model.id,
cls.model.avatar,
cls.model.title,
cls.model.dsl,
cls.model.description,
cls.model.permission,
cls.model.user_id.alias("tenant_id"),
@ -194,7 +195,6 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
files = kwargs.get("files", [])
inputs = kwargs.get("inputs", {})
user_id = kwargs.get("user_id", "")
custom_header = kwargs.get("custom_header", "")
if session_id:
e, conv = API4ConversationService.get_by_id(session_id)
@ -203,7 +203,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
conv.message = []
if not isinstance(conv.dsl, str):
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
canvas = Canvas(conv.dsl, tenant_id, agent_id, custom_header=custom_header)
canvas = Canvas(conv.dsl, tenant_id, agent_id)
else:
e, cvs = UserCanvasService.get_by_id(agent_id)
assert e, "Agent not found."
@ -211,7 +211,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
session_id=get_uuid()
canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header)
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
canvas.reset()
conv = {
"id": session_id,

View File

@ -190,15 +190,10 @@ class CommonService:
data_list (list): List of dictionaries containing record data to insert.
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
"""
current_ts = current_timestamp()
current_datetime = datetime_format(datetime.now())
with DB.atomic():
for d in data_list:
d["create_time"] = current_ts
d["create_date"] = current_datetime
d["update_time"] = current_ts
d["update_date"] = current_datetime
d["create_time"] = current_timestamp()
d["create_date"] = datetime_format(datetime.now())
for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i : i + batch_size]).execute()
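A brief sketch (toy Item model) of the batched bulk-insert pattern above: audit fields are stamped on every row, then the rows are written in slices of batch_size inside a single atomic block.

from datetime import datetime
from peewee import SqliteDatabase, Model, CharField

db = SqliteDatabase(":memory:")

class Item(Model):
    name = CharField()
    create_date = CharField()

    class Meta:
        database = db

db.connect()
db.create_tables([Item])

def bulk_insert(rows, batch_size=100):
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")   # computed once, applied to every row
    with db.atomic():
        for row in rows:
            row["create_date"] = now
        for i in range(0, len(rows), batch_size):
            Item.insert_many(rows[i:i + batch_size]).execute()

bulk_insert([{"name": f"item_{n}"} for n in range(250)], batch_size=100)
assert Item.select().count() == 250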

View File

@ -25,11 +25,11 @@ from api.db import InputType
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocMetadataService
from common.misc_utils import get_uuid
from common.constants import TaskStatus
from common.time_utils import current_timestamp, timestamp_to_date
class ConnectorService(CommonService):
model = Connector
@ -202,7 +202,6 @@ class SyncLogsService(CommonService):
return None
class FileObj(BaseModel):
id: str
filename: str
blob: bytes
@ -210,7 +209,7 @@ class SyncLogsService(CommonService):
return self.blob
errs = []
files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
doc_ids = []
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
errs.extend(err)
@ -228,7 +227,7 @@ class SyncLogsService(CommonService):
# Set metadata if available for this document
if doc["name"] in metadata_map:
DocMetadataService.update_document_metadata(doc["id"], metadata_map[doc["name"]])
DocumentService.update_by_id(doc["id"], {"meta_fields": metadata_map[doc["name"]]})
if not auto_parse or auto_parse == "0":
continue
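A standalone sketch of the filename handling used when building the FileObj entries above: the connector's extension is appended only when it is not already present in the semantic identifier (the original expresses the membership test with a reversed-string find).

def with_extension(name: str, ext: str) -> str:
    # name[::-1].find(ext[::-1]) < 0 in the original is equivalent to `ext not in name`.
    return name if ext in name else name + ext

assert with_extension("report.pdf", ".pdf") == "report.pdf"
assert with_extension("report", ".pdf") == "report.pdf"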

View File

@ -64,13 +64,11 @@ class ConversationService(CommonService):
offset += limit
return res
def structure_answer(conv, ans, message_id, session_id):
reference = ans["reference"]
if not isinstance(reference, dict):
reference = {}
ans["reference"] = {}
is_final = ans.get("final", True)
chunk_list = chunks_format(reference)
@ -83,32 +81,14 @@ def structure_answer(conv, ans, message_id, session_id):
if not conv.message:
conv.message = []
content = ans["answer"]
if ans.get("start_to_think"):
content = "<think>"
elif ans.get("end_to_think"):
content = "</think>"
if not conv.message or conv.message[-1].get("role", "") != "assistant":
conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
else:
if is_final:
if ans.get("answer"):
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
else:
conv.message[-1]["created_at"] = time.time()
conv.message[-1]["id"] = message_id
else:
conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
conv.message[-1]["created_at"] = time.time()
conv.message[-1]["id"] = message_id
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
if conv.reference:
should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
if should_update_reference:
conv.reference[-1] = reference
conv.reference[-1] = reference
return ans
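A compact sketch (standalone, simplified fields) of how structure_answer above folds streamed events into the conversation: think markers and partial deltas are appended to the last assistant message, and the final event replaces its content with the full answer.

def fold_event(messages, event, message_id):
    content = event.get("answer", "")
    if event.get("start_to_think"):
        content = "<think>"
    elif event.get("end_to_think"):
        content = "</think>"
    if not messages or messages[-1].get("role") != "assistant":
        messages.append({"role": "assistant", "content": content, "id": message_id})
    elif event.get("final", True):
        if event.get("answer"):
            messages[-1] = {"role": "assistant", "content": event["answer"], "id": message_id}
    else:
        messages[-1]["content"] += content
    return messages

msgs = []
for ev in [{"answer": "", "start_to_think": True, "final": False},
           {"answer": "partial ", "final": False},
           {"answer": "", "end_to_think": True, "final": False},
           {"answer": "full answer", "final": True}]:
    fold_event(msgs, ev, "m1")
assert msgs[-1]["content"] == "full answer"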
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
assert name, "`name` can not be empty."
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
@ -136,16 +116,6 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
return
else:
answer = {
"answer": conv["message"][0]["content"],
"reference": {},
"audio_binary": None,
"id": None,
"session_id": session_id
}
yield answer
return
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
if not conv:

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import binascii
import logging
import re
@ -24,19 +23,20 @@ from functools import partial
from timeit import default_timer as timer
from langfuse import Langfuse
from peewee import fn
from agentic_reasoning import DeepResearcher
from api.db.services.file_service import FileService
from common.constants import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.tenant_llm_service import TenantLLMService
from common.time_utils import current_timestamp, datetime_format
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
from rag.advanced_rag import DeepResearcher
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@ -196,13 +196,19 @@ async def async_chat_solo(dialog, messages, stream=True):
if attachments and msg:
msg[-1]["content"] += attachments
if stream:
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
async for kind, value, state in _stream_with_think_delta(stream_iter):
if kind == "marker":
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
last_ans = ""
delta_ans = ""
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
delta_ans = ""
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
else:
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
@ -273,7 +279,6 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
async def async_chat(dialog, messages, stream=True, **kwargs):
logging.debug("Begin async_chat")
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
async for ans in async_chat_solo(dialog, messages, stream):
@ -296,14 +301,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
try:
if langfuse.auth_check():
langfuse_tracer = langfuse
trace_id = langfuse_tracer.create_trace_id()
trace_context = {"trace_id": trace_id}
except Exception:
# Skip langfuse tracing if connection fails
pass
if langfuse.auth_check():
langfuse_tracer = langfuse
trace_id = langfuse_tracer.create_trace_id()
trace_context = {"trace_id": trace_id}
check_langfuse_tracer_ts = timer()
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
@ -323,20 +324,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
logging.debug(f"field_map retrieved: {field_map}")
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
# For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
if ans:
yield ans
return
else:
logging.debug("SQL failed or returned no results, falling back to vector search")
param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -355,7 +349,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
if dialog.meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids)
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
attachments = await apply_meta_data_filter(
dialog.meta_data_filter,
metas,
@ -373,11 +367,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
knowledges = []
if attachments is not None and "knowledge" in param_keys:
logging.debug("Proceeding with retrieval")
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
if prompt_config.get("reasoning", False) or kwargs.get("reasoning"):
if prompt_config.get("reasoning", False):
reasoner = DeepResearcher(
chat_mdl,
prompt_config,
@ -393,28 +386,16 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
doc_ids=attachments,
),
)
queue = asyncio.Queue()
async def callback(msg:str):
nonlocal queue
await queue.put(msg + "<br/>")
await callback("<START_DEEP_RESEARCH>")
task = asyncio.create_task(reasoner.research(kbinfos, questions[-1], questions[-1], callback=callback))
while True:
msg = await queue.get()
if msg.find("<START_DEEP_RESEARCH>") == 0:
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
elif msg.find("<END_DEEP_RESEARCH>") == 0:
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
break
else:
yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
await task
async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
if isinstance(think, str):
thought = think
knowledges = [t for t in think.split("\n") if t]
elif stream:
yield think
else:
if embd_mdl:
kbinfos = await retriever.retrieval(
kbinfos = retriever.retrieval(
" ".join(questions),
embd_mdl,
tenant_ids,
@ -430,7 +411,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
rank_feature=label_question(" ".join(questions), kbs),
)
if prompt_config.get("toc_enhance"):
cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
@ -440,19 +421,21 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
knowledges = kb_prompt(kbinfos, max_tokens)
knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res), "final": True}
"audio_binary": tts(tts_mdl, empty_res)}
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@ -555,22 +538,21 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
)
if stream:
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
last_state = None
async for kind, value, state in _stream_with_think_delta(stream_iter):
last_state = state
if kind == "marker":
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
last_ans = ""
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
full_answer = last_state.full_text if last_state else ""
if full_answer:
final = decorate_answer(thought + full_answer)
final["final"] = True
final["audio_binary"] = None
final["answer"] = ""
yield final
last_ans = answer
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):]
if delta_ans:
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
else:
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
@ -583,342 +565,112 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
logging.debug(f"use_sql: Question: {question}")
# Determine which document engine we're using
if settings.DOC_ENGINE_INFINITY:
doc_engine = "infinity"
elif settings.DOC_ENGINE_OCEANBASE:
doc_engine = "oceanbase"
else:
doc_engine = "es"
# Construct the full table name
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
base_table = index_name(tenant_id)
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
# Infinity: append kb_id to table name
table_name = f"{base_table}_{kb_ids[0]}"
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
else:
# Elasticsearch/OpenSearch: use base index name
table_name = base_table
logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
# Generate engine-specific SQL prompts
if doc_engine == "infinity":
# Build Infinity prompts with JSON extraction context
json_field_names = list(field_map.keys())
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
RULES:
1. Use EXACT field names (case-sensitive) from the list below
2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
4. Add AS alias for extracted field names
5. DO NOT select 'content' field
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
- Question asks to "show me" or "display" specific columns
- Question mentions "not null" or "excluding null"
- Add NULL check for count specific column
- DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
7. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Fields (EXACT case): {}
sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
"""
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
table_name,
", ".join(json_field_names),
"\n".join([f" - {field}" for field in json_field_names]),
question
)
elif doc_engine == "oceanbase":
# Build OceanBase prompts with JSON extraction context
json_field_names = list(field_map.keys())
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
RULES:
1. Use EXACT field names (case-sensitive) from the list below
2. For SELECT: include doc_id, docnm_kwd, and json_extract_string() for requested fields
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
4. Add AS alias for extracted field names
5. DO NOT select 'content' field
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
- Question asks to "show me" or "display" specific columns
- Question mentions "not null" or "excluding null"
- Add NULL check for count specific column
- DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
7. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Fields (EXACT case): {}
Question are as follows:
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
table_name,
", ".join(json_field_names),
"\n".join([f" - {field}" for field in json_field_names]),
question
)
else:
# Build ES/OS prompts with direct field access
sys_prompt = """You are a Database Administrator. Write SQL queries.
RULES:
1. Use EXACT field names from the schema below (e.g., product_tks, not product)
2. Quote field names starting with digit: "123_field"
3. Add IS NOT NULL in WHERE clause when:
- Question asks to "show me" or "display" specific columns
4. Include doc_id/docnm in non-aggregate statement
5. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Available fields:
{}
Question: {}
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
table_name,
"\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
question
)
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
tried_times = 0
async def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
# Remove think blocks if present (format: </think>...)
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
# Remove markdown code blocks (```sql ... ```)
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
# Remove trailing semicolon that ES SQL parser doesn't like
sql = sql.rstrip().rstrip(';').strip()
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
if doc_engine != "infinity" and kb_ids:
# Build kb_filter: single KB or multiple KBs with OR
if len(kb_ids) == 1:
kb_filter = f"kb_id = '{kb_ids[0]}'"
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"([;]|```).*", "", sql)
sql = re.sub(r"&", "and", sql)
if sql[: len("select ")] != "select ":
return None, None
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
if sql[: len("select *")] != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
else:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
flds = []
for k in field_map.keys():
if k in forbidden_select_fields4resume:
continue
if len(flds) > 11:
break
flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
if "where " not in sql.lower():
if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower():
o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
else:
sql += f" AND {kb_filter}"
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
tbl = settings.retriever.sql_retrieval(sql, format="json")
if tbl is None:
logging.debug("use_sql: SQL retrieval returned None")
return None, sql
logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
return tbl, sql
return settings.retriever.sql_retrieval(sql, format="json"), sql
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
except Exception as e:
logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
# Build retry prompt with error information
if doc_engine in ("infinity", "oceanbase"):
# Build Infinity error retry prompt
json_field_names = list(field_map.keys())
user_prompt = """
Table name: {};
JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
{}
Question: {}
Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
The SQL error you provided last time is as follows:
{}
Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
else:
# Build ES/OS error retry prompt
user_prompt = """
user_prompt = """
Table name: {};
Table of database fields are as follows (use the field names directly in SQL):
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
except Exception:
logging.error("use_sql: Retry SQL execution also FAILED, returning None")
return
if len(tbl["rows"]) == 0:
logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
return None
logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
logging.debug(f"use_sql: column_idx={column_idx}")
logging.debug(f"use_sql: field_map={field_map}")
# Helper function to map column names to display names
def map_column_name(col_name):
if col_name.lower() == "count(star)":
return "COUNT(*)"
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
# Pattern: anything AS alias_name
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
if as_match:
alias = as_match.group(1).strip('"\'')
# Use the alias for display name lookup
if alias in field_map:
display = field_map[alias]
return re.sub(r"(/.*|[^]+)", "", display)
# If alias not in field_map, try to match case-insensitively
for field_key, display_value in field_map.items():
if field_key.lower() == alias.lower():
return re.sub(r"(/.*|[^]+)", "", display_value)
# Return alias as-is if no mapping found
return alias
# Try direct mapping first (for simple column names)
if col_name in field_map:
display = field_map[col_name]
# Clean up any suffix patterns
return re.sub(r"(/.*|[^]+)", "", display)
# Try case-insensitive match for simple column names
col_lower = col_name.lower()
for field_key, display_value in field_map.items():
if field_key.lower() == col_lower:
return re.sub(r"(/.*|[^]+)", "", display_value)
# For aggregate expressions or complex expressions without AS alias,
# try to replace field names with display names
result = col_name
for field_name, display_name in field_map.items():
# Replace field_name with display_name in the expression
result = result.replace(field_name, display_name)
# Clean up any suffix patterns
result = re.sub(r"(/.*|[^]+)", "", result)
return result
# compose Markdown table
columns = (
"|" + "|".join(
[map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
"|Source|" if docid_idx and doc_name_idx else "|")
[re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
"|Source|" if docid_idx and docid_idx else "|")
)
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
# Build rows ensuring column names match values - create a dict for each row
# keyed by column name to handle any SQL column order
rows = []
for row_idx, r in enumerate(tbl["rows"]):
row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
if row_idx == 0:
logging.debug(f"use_sql: First row data: {row_dict}")
row_values = []
for col_idx in column_idx:
col_name = tbl["columns"][col_idx]["name"]
value = row_dict.get(col_name, " ")
row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
# Add Source column with citation marker if Source column exists
if docid_idx and doc_name_idx:
row_values.append(f" ##{row_idx}$$")
row_str = "|" + "|".join(row_values) + "|"
if re.sub(r"[ |]+", "", row_str):
rows.append(row_str)
rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
if quota:
rows = "\n".join(rows)
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else:
rows = "\n".join(rows)
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not doc_name_idx:
logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
# For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
# to provide source chunks, but keep the original table format answer
if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
# Keep original table format as answer
answer = "\n".join([columns, line, rows])
# Now fetch doc_id, docnm_kwd to provide source chunks
# Extract WHERE clause from the original SQL
where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
if where_match:
where_clause = where_match.group(1).strip()
# Build a query to get doc_id and docnm_kwd with the same WHERE clause
chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
# Add LIMIT to avoid fetching too many chunks
if "limit" not in chunks_sql.lower():
chunks_sql += " limit 20"
logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
try:
chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
# Build chunks reference - use case-insensitive matching
chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
if chunks_did_idx is not None and chunks_dn_idx is not None:
chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
# Build doc_aggs
doc_aggs = {}
for r in chunks_tbl["rows"]:
doc_id = r[chunks_did_idx]
doc_name = r[chunks_dn_idx]
if doc_id not in doc_aggs:
doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
doc_aggs[doc_id]["count"] += 1
doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
except Exception as e:
logging.warning(f"use_sql: Failed to fetch chunks: {e}")
# Fallback: return answer without chunks
return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
# Fallback to table format for other cases
logging.warning("SQL missing field: " + sql)
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
docid_idx = list(docid_idx)[0]
@ -928,8 +680,7 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
result = {
return {
"answer": "\n".join([columns, line, rows]),
"reference": {
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@ -937,8 +688,6 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat
},
"prompt": sys_prompt,
}
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
return result
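For illustration, a standalone sketch of the regex-based WHERE-clause reuse that the aggregate branch of use_sql above relies on; the sample SQL and the table name "ragflow_doc" are made-up placeholders, not taken from the repository.

# Illustrative only: mirrors the WHERE-clause extraction and chunk-query rebuild shown above.
import re

sql = "select count(*) from ragflow_doc where json_extract_string(chunk_data, '$.city') = 'Boston' order by 1"
where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
if where_match:
    where_clause = where_match.group(1).strip()
    chunks_sql = f"select doc_id, docnm_kwd from ragflow_doc where {where_clause}"
    if "limit" not in chunks_sql.lower():
        chunks_sql += " limit 20"
    print(chunks_sql)
    # -> select doc_id, docnm_kwd from ragflow_doc where json_extract_string(chunk_data, '$.city') = 'Boston' limit 20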
def clean_tts_text(text: str) -> str:
if not text:
@ -984,84 +733,6 @@ def tts(tts_mdl, text):
return None
return binascii.hexlify(bin).decode("utf-8")
class _ThinkStreamState:
def __init__(self) -> None:
self.full_text = ""
self.last_idx = 0
self.endswith_think = False
self.last_full = ""
self.last_model_full = ""
self.in_think = False
self.buffer = ""
def _next_think_delta(state: _ThinkStreamState) -> str:
full_text = state.full_text
if full_text == state.last_full:
return ""
state.last_full = full_text
delta_ans = full_text[state.last_idx:]
if delta_ans.find("<think>") == 0:
state.last_idx += len("<think>")
return "<think>"
if delta_ans.find("<think>") > 0:
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
state.last_idx += delta_ans.find("<think>")
return delta_text
if delta_ans.endswith("</think>"):
state.endswith_think = True
elif state.endswith_think:
state.endswith_think = False
return "</think>"
state.last_idx = len(full_text)
if full_text.endswith("</think>"):
state.last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
state = _ThinkStreamState()
async for chunk in stream_iter:
if not chunk:
continue
if chunk.startswith(state.last_model_full):
new_part = chunk[len(state.last_model_full):]
state.last_model_full = chunk
else:
new_part = chunk
state.last_model_full += chunk
if not new_part:
continue
state.full_text += new_part
delta = _next_think_delta(state)
if not delta:
continue
if delta in ("<think>", "</think>"):
if delta == "<think>" and state.in_think:
continue
if delta == "</think>" and not state.in_think:
continue
if state.buffer:
yield ("text", state.buffer, state)
state.buffer = ""
state.in_think = delta == "<think>"
yield ("marker", delta, state)
continue
state.buffer += delta
if num_tokens_from_string(state.buffer) < min_tokens:
continue
yield ("text", state.buffer, state)
state.buffer = ""
if state.buffer:
yield ("text", state.buffer, state)
state.buffer = ""
if state.endswith_think:
yield ("marker", "</think>", state)
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
@ -1084,10 +755,10 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
kbinfos = await retriever.retrieval(
kbinfos = retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
@ -1127,20 +798,11 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
refs["chunks"] = chunks_format(refs)
return {"answer": answer, "reference": refs}
stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
last_state = None
async for kind, value, state in _stream_with_think_delta(stream_iter):
last_state = state
if kind == "marker":
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
yield {"answer": "", "reference": {}, "final": False, **flags}
continue
yield {"answer": value, "reference": {}, "final": False}
full_answer = last_state.full_text if last_state else ""
final = decorate_answer(full_answer)
final["final"] = True
final["answer"] = ""
yield final
answer = ""
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
@ -1160,10 +822,10 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
if meta_data_filter:
metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
ranks = await settings.retriever.retrieval(
ranks = settings.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,

File diff suppressed because it is too large

View File

@ -33,7 +33,7 @@ from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTena
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.doc_metadata_service import DocMetadataService
from common.metadata_utils import dedupe_list
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
@ -67,6 +67,7 @@ class DocumentService(CommonService):
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.meta_fields,
cls.model.suffix,
cls.model.run,
cls.model.status,
@ -109,12 +110,7 @@ class DocumentService(CommonService):
count = docs.count()
docs = docs.paginate(page_number, items_per_page)
docs_list = list(docs.dicts())
metadata_map = DocMetadataService.get_metadata_for_documents(None, kb_id)
for doc in docs_list:
doc["meta_fields"] = metadata_map.get(doc["id"], {})
return docs_list, count
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
@ -129,26 +125,26 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids=None, return_empty_metadata=False):
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, run_status, types, suffix, doc_ids=None):
fields = cls.get_cls_model_fields()
if keywords:
docs = (
cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(File, on=(File.id == File2Document.file_id))
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
.where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
)
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
.join(File, on=(File.id == File2Document.file_id))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
docs = (
cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])
.join(File2Document, on=(File2Document.document_id == cls.model.id))
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)
.join(File, on=(File.id == File2Document.file_id))
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.join(File, on=(File.id == File2Document.file_id))\
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
.where(cls.model.kb_id == kb_id)
)
if doc_ids:
docs = docs.where(cls.model.id.in_(doc_ids))
@ -159,28 +155,17 @@ class DocumentService(CommonService):
if suffix:
docs = docs.where(cls.model.suffix.in_(suffix))
metadata_map = DocMetadataService.get_metadata_for_documents(None, kb_id)
doc_ids_with_metadata = set(metadata_map.keys())
if return_empty_metadata and doc_ids_with_metadata:
docs = docs.where(cls.model.id.not_in(doc_ids_with_metadata))
count = docs.count()
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
docs = docs.paginate(page_number, items_per_page)
docs_list = list(docs.dicts())
if return_empty_metadata:
for doc in docs_list:
doc["meta_fields"] = {}
else:
for doc in docs_list:
doc["meta_fields"] = metadata_map.get(doc["id"], {})
return docs_list, count
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
@ -226,30 +211,24 @@ class DocumentService(CommonService):
if suffix:
query = query.where(cls.model.suffix.in_(suffix))
rows = query.select(cls.model.run, cls.model.suffix, cls.model.id)
rows = query.select(cls.model.run, cls.model.suffix, cls.model.meta_fields)
total = rows.count()
suffix_counter = {}
run_status_counter = {}
metadata_counter = {}
empty_metadata_count = 0
doc_ids = [row.id for row in rows]
metadata = {}
if doc_ids:
try:
metadata = DocMetadataService.get_metadata_for_documents(doc_ids, kb_id)
except Exception as e:
logging.warning(f"Failed to fetch metadata from ES/Infinity: {e}")
for row in rows:
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
meta_fields = metadata.get(row.id, {})
if not meta_fields:
empty_metadata_count += 1
meta_fields = row.meta_fields or {}
if isinstance(meta_fields, str):
try:
meta_fields = json.loads(meta_fields)
except Exception:
meta_fields = {}
if not isinstance(meta_fields, dict):
continue
has_valid_meta = False
for key, value in meta_fields.items():
values = value if isinstance(value, list) else [value]
for vv in values:
@ -261,11 +240,7 @@ class DocumentService(CommonService):
if key not in metadata_counter:
metadata_counter[key] = {}
metadata_counter[key][sv] = metadata_counter[key].get(sv, 0) + 1
has_valid_meta = True
if not has_valid_meta:
empty_metadata_count += 1
metadata_counter["empty_metadata"] = {"true": empty_metadata_count}
return {
"suffix": suffix_counter,
"run_status": run_status_counter,
@ -360,50 +335,30 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
from api.db.services.task_service import TaskService, cancel_all_task_of
from api.db.services.task_service import TaskService
cls.clear_chunk_num(doc.id)
# Cancel all running tasks first, using the helper in task_service.py (sets a cancel flag in Redis)
try:
cancel_all_task_of(doc.id)
logging.info(f"Cancelled all tasks for document {doc.id}")
except Exception as e:
logging.warning(f"Failed to cancel tasks for document {doc.id}: {e}")
# Delete tasks from database
try:
TaskService.filter_delete([Task.doc_id == doc.id])
except Exception as e:
logging.warning(f"Failed to delete tasks for document {doc.id}: {e}")
# Delete chunk images (non-critical, log and continue)
try:
cls.delete_chunk_images(doc, tenant_id)
except Exception as e:
logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}")
# Delete thumbnail (non-critical, log and continue)
try:
page = 0
page_size = 1000
all_chunk_ids = []
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
page += 1
for cid in all_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
except Exception as e:
logging.warning(f"Failed to delete thumbnail for document {doc.id}: {e}")
# Delete chunks from doc store - this is critical, log errors
try:
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
except Exception as e:
logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}")
# Delete document metadata (non-critical, log and continue)
try:
DocMetadataService.delete_document_metadata(doc.id)
except Exception as e:
logging.warning(f"Failed to delete metadata for document {doc.id}: {e}")
# Cleanup knowledge graph references (non-critical, log and continue)
try:
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
@ -416,28 +371,10 @@ class DocumentService(CommonService):
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception as e:
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
except Exception:
pass
return cls.delete_by_id(doc.id)
@classmethod
@DB.connection_context()
def delete_chunk_images(cls, doc, tenant_id):
page = 0
page_size = 1000
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
for cid in chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
page += 1
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls):
@ -480,7 +417,6 @@ class DocumentService(CommonService):
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
return list(docs.dicts())
@ -703,7 +639,8 @@ class DocumentService(CommonService):
if k not in old:
old[k] = v
continue
if isinstance(v, dict) and isinstance(old[k], dict):
if isinstance(v, dict):
assert isinstance(old[k], dict)
dfs_update(old[k], v)
else:
old[k] = v
@ -735,6 +672,209 @@ class DocumentService(CommonService):
cls.update_by_id(doc_id, info)
@classmethod
@DB.connection_context()
def update_meta_fields(cls, doc_id, meta_fields):
return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
@classmethod
@DB.connection_context()
def get_meta_by_kbs(cls, kb_ids):
"""
Legacy metadata aggregator (backward-compatible).
- Does NOT expand list values and a list is kept as one string key.
Example: {"tags": ["foo","bar"]} -> meta["tags"]["['foo', 'bar']"] = [doc_id]
- Expects meta_fields is a dict.
Use when existing callers rely on the old list-as-string semantics.
"""
fields = [
cls.model.id,
cls.model.meta_fields,
]
meta = {}
for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
doc_id = r.id
for k,v in r.meta_fields.items():
if k not in meta:
meta[k] = {}
if not isinstance(v, list):
v = [v]
for vv in v:
if vv not in meta[k]:
if isinstance(vv, list) or isinstance(vv, dict):
continue
meta[k][vv] = []
meta[k][vv].append(doc_id)
return meta
@classmethod
@DB.connection_context()
def get_flatted_meta_by_kbs(cls, kb_ids):
"""
- Parses stringified JSON meta_fields when possible and skips non-dict or unparsable values.
- Expands list values into individual entries.
Example: {"tags": ["foo","bar"], "author": "alice"} ->
meta["tags"]["foo"] = [doc_id], meta["tags"]["bar"] = [doc_id], meta["author"]["alice"] = [doc_id]
Preferred for metadata_condition filtering and other scenarios that must respect list semantics.
"""
fields = [
cls.model.id,
cls.model.meta_fields,
]
meta = {}
for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
doc_id = r.id
meta_fields = r.meta_fields or {}
if isinstance(meta_fields, str):
try:
meta_fields = json.loads(meta_fields)
except Exception:
continue
if not isinstance(meta_fields, dict):
continue
for k, v in meta_fields.items():
if k not in meta:
meta[k] = {}
values = v if isinstance(v, list) else [v]
for vv in values:
if vv is None:
continue
sv = str(vv)
if sv not in meta[k]:
meta[k][sv] = []
meta[k][sv].append(doc_id)
return meta
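A self-contained sketch of the flattening described in the docstring above, applied to in-memory rows instead of the database query; the doc ids and metadata are invented.

# Standalone illustration of the flattening semantics; rows mimic (doc id, meta_fields) pairs.
rows = [
    ("doc_a", {"tags": ["foo", "bar"], "author": "alice"}),
    ("doc_b", {"tags": "foo", "year": 2024}),
]
meta = {}
for doc_id, meta_fields in rows:
    for k, v in meta_fields.items():
        values = v if isinstance(v, list) else [v]
        for vv in values:
            if vv is None:
                continue
            meta.setdefault(k, {}).setdefault(str(vv), []).append(doc_id)

assert meta["tags"]["foo"] == ["doc_a", "doc_b"]
assert meta["tags"]["bar"] == ["doc_a"]
assert meta["author"]["alice"] == ["doc_a"]
assert meta["year"]["2024"] == ["doc_b"]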
@classmethod
@DB.connection_context()
def get_metadata_summary(cls, kb_id):
fields = [cls.model.id, cls.model.meta_fields]
summary = {}
for r in cls.model.select(*fields).where(cls.model.kb_id == kb_id):
meta_fields = r.meta_fields or {}
if isinstance(meta_fields, str):
try:
meta_fields = json.loads(meta_fields)
except Exception:
continue
if not isinstance(meta_fields, dict):
continue
for k, v in meta_fields.items():
values = v if isinstance(v, list) else [v]
for vv in values:
if not vv:
continue
sv = str(vv)
if k not in summary:
summary[k] = {}
summary[k][sv] = summary[k].get(sv, 0) + 1
return {k: sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True) for k, v in summary.items()}
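A similar in-memory sketch of the summary shape returned above: per key, a list of (value, count) pairs sorted by count in descending order; the sample data is invented.

# Standalone illustration of the summary shape: {key: [(value, count), ...]} sorted by count, descending.
docs_meta = [
    {"author": "alice", "tags": ["ml", "nlp"]},
    {"author": "alice", "tags": "ml"},
    {"author": "bob"},
]
summary = {}
for meta_fields in docs_meta:
    for k, v in meta_fields.items():
        for vv in (v if isinstance(v, list) else [v]):
            if not vv:
                continue
            summary.setdefault(k, {})
            summary[k][str(vv)] = summary[k].get(str(vv), 0) + 1

result = {k: sorted(v.items(), key=lambda x: x[1], reverse=True) for k, v in summary.items()}
assert result["author"] == [("alice", 2), ("bob", 1)]
assert result["tags"] == [("ml", 2), ("nlp", 1)]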
@classmethod
@DB.connection_context()
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None):
updates = updates or []
deletes = deletes or []
if not doc_ids:
return 0
def _normalize_meta(meta):
if isinstance(meta, str):
try:
meta = json.loads(meta)
except Exception:
return {}
if not isinstance(meta, dict):
return {}
return deepcopy(meta)
def _str_equal(a, b):
return str(a) == str(b)
def _apply_updates(meta):
changed = False
for upd in updates:
key = upd.get("key")
if not key or key not in meta:
continue
new_value = upd.get("value")
match_provided = "match" in upd
if isinstance(meta[key], list):
if not match_provided:
if isinstance(new_value, list):
meta[key] = dedupe_list(new_value)
else:
meta[key] = new_value
changed = True
else:
match_value = upd.get("match")
replaced = False
new_list = []
for item in meta[key]:
if _str_equal(item, match_value):
new_list.append(new_value)
replaced = True
else:
new_list.append(item)
if replaced:
meta[key] = dedupe_list(new_list)
changed = True
else:
if not match_provided:
meta[key] = new_value
changed = True
else:
match_value = upd.get("match")
if _str_equal(meta[key], match_value):
meta[key] = new_value
changed = True
return changed
def _apply_deletes(meta):
changed = False
for d in deletes:
key = d.get("key")
if not key or key not in meta:
continue
value = d.get("value", None)
if isinstance(meta[key], list):
if value is None:
del meta[key]
changed = True
continue
new_list = [item for item in meta[key] if not _str_equal(item, value)]
if len(new_list) != len(meta[key]):
if new_list:
meta[key] = new_list
else:
del meta[key]
changed = True
else:
if value is None or _str_equal(meta[key], value):
del meta[key]
changed = True
return changed
updated_docs = 0
with DB.atomic():
rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
(cls.model.id.in_(doc_ids)) & (cls.model.kb_id == kb_id)
)
for r in rows:
meta = _normalize_meta(r.meta_fields or {})
original_meta = deepcopy(meta)
changed = _apply_updates(meta)
changed = _apply_deletes(meta) or changed
if changed and meta != original_meta:
cls.model.update(
meta_fields=meta,
update_time=current_timestamp(),
update_date=get_format_time()
).where(cls.model.id == r.id).execute()
updated_docs += 1
return updated_docs
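A sketch of the payload shapes batch_update_metadata accepts; the kb/doc ids and field names are made up, and the commented-out call assumes DocumentService is importable in the caller's scope.

# Illustrative payloads only; ids and field names are invented.
updates = [
    {"key": "author", "match": "alice", "value": "bob"},   # replace only entries equal to "alice"
    {"key": "status", "value": "reviewed"},                # no "match": overwrite the current value
]
deletes = [
    {"key": "tags", "value": "draft"},   # drop one list entry; the key is removed if the list empties
    {"key": "obsolete_field"},           # no "value": drop the key entirely
]
# Keys absent from a document's meta_fields are skipped for both updates and deletes.
# DocumentService.batch_update_metadata("kb_001", ["doc_001", "doc_002"], updates=updates, deletes=deletes)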
@classmethod
@DB.connection_context()
def update_progress(cls):
@ -768,8 +908,6 @@ class DocumentService(CommonService):
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
status = doc.run # TaskStatus.RUNNING.value
if status == TaskStatus.CANCEL.value:
continue
doc_progress = doc.progress if doc and doc.progress else 0.0
special_task_running = False
priority = 0
@ -813,16 +951,7 @@ class DocumentService(CommonService):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
info["update_time"] = current_timestamp()
info["update_date"] = get_format_time()
(
cls.model.update(info)
.where(
(cls.model.id == d["id"])
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
)
.execute()
)
cls.update_by_id(d["id"], info)
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@ -855,7 +984,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
# cancelled: run == "2"
# cancelled: run == "2" but progress can vary
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
@ -1082,7 +1211,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
cks = [c for c in docs if c["doc_id"] == doc_id]
if parser_ids[doc_id] != ParserType.PICTURE.value:
from rag.graphrag.general.mind_map_extractor import MindMapExtractor
from graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
@ -1110,7 +1239,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

View File

@ -65,7 +65,6 @@ class EvaluationService(CommonService):
(success, dataset_id or error_message)
"""
try:
timestamp= current_timestamp()
dataset_id = get_uuid()
dataset = {
"id": dataset_id,
@ -74,8 +73,8 @@ class EvaluationService(CommonService):
"description": description,
"kb_ids": kb_ids,
"created_by": user_id,
"create_time": timestamp,
"update_time": timestamp,
"create_time": current_timestamp(),
"update_time": current_timestamp(),
"status": StatusEnum.VALID.value
}
@ -225,36 +224,21 @@ class EvaluationService(CommonService):
"""
success_count = 0
failure_count = 0
case_instances = []
if not cases:
return success_count, failure_count
for case_data in cases:
success, _ = cls.add_test_case(
dataset_id=dataset_id,
question=case_data.get("question", ""),
reference_answer=case_data.get("reference_answer"),
relevant_doc_ids=case_data.get("relevant_doc_ids"),
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
metadata=case_data.get("metadata")
)
cur_timestamp = current_timestamp()
try:
for case_data in cases:
case_id = get_uuid()
case_info = {
"id": case_id,
"dataset_id": dataset_id,
"question": case_data.get("question", ""),
"reference_answer": case_data.get("reference_answer"),
"relevant_doc_ids": case_data.get("relevant_doc_ids"),
"relevant_chunk_ids": case_data.get("relevant_chunk_ids"),
"metadata": case_data.get("metadata"),
"create_time": cur_timestamp
}
case_instances.append(EvaluationCase(**case_info))
EvaluationCase.bulk_create(case_instances, batch_size=300)
success_count = len(case_instances)
failure_count = 0
except Exception as e:
logging.error(f"Error bulk importing test cases: {str(e)}")
failure_count = len(cases)
success_count = 0
if success:
success_count += 1
else:
failure_count += 1
return success_count, failure_count

View File

@ -439,15 +439,6 @@ class FileService(CommonService):
err, files = [], []
for file in file_objs:
doc_id = file.id if hasattr(file, "id") else get_uuid()
e, doc = DocumentService.get_by_id(doc_id)
if e:
blob = file.read()
settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id)
doc.size = len(blob)
doc = doc.to_dict()
DocumentService.update_by_id(doc["id"], doc)
continue
try:
DocumentService.check_doc_health(kb.tenant_id, file.filename)
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
@ -464,6 +455,7 @@ class FileService(CommonService):
blob = read_potential_broken_pdf(blob)
settings.STORAGE_IMPL.put(kb.id, location, blob)
doc_id = get_uuid()
img = thumbnail_img(filename, blob)
thumbnail_location = ""

Some files were not shown because too many files have changed in this diff