Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-23 03:26:53 +08:00)

Compare commits: 377 commits, d38f8a1562...nightly
.github/workflows/release.yml (vendored): 22 changed lines

@@ -10,6 +10,12 @@ on:
     tags:
       - "v*.*.*" # normal release

+permissions:
+  contents: write
+  actions: read
+  checks: read
+  statuses: read
+
 # https://docs.github.com/en/actions/using-jobs/using-concurrency
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

@@ -76,6 +82,14 @@ jobs:
           # The body field does not support environment variable substitution directly.
           body_path: release_body.md

+      - name: Build and push image
+        run: |
+          sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
+          sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
+          sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
+          sudo docker push infiniflow/ragflow:${RELEASE_TAG}
+          sudo docker push infiniflow/ragflow:latest
+
       - name: Build and push ragflow-sdk
         if: startsWith(github.ref, 'refs/tags/v')
         run: |

@@ -85,11 +99,3 @@ jobs:
         if: startsWith(github.ref, 'refs/tags/v')
         run: |
           cd admin/client && uv build && uv publish --token ${{ secrets.PYPI_API_TOKEN }}
-
-      - name: Build and push image
-        run: |
-          sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
-          sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
-          sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
-          sudo docker push infiniflow/ragflow:${RELEASE_TAG}
-          sudo docker push infiniflow/ragflow:latest
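One detail worth noting in the relocated build step: the Docker Hub token is fed to `docker login` over stdin rather than being passed as a flag. A minimal standalone sketch of that pattern, assuming `DOCKERHUB_TOKEN` and `RELEASE_TAG` are supplied by the environment (as the workflow's secrets and variables would be):

```bash
#!/usr/bin/env bash
# Sketch only: log in to Docker Hub without putting the token on the command line.
# DOCKERHUB_TOKEN and RELEASE_TAG are assumed to be exported by the caller (e.g. CI secrets).
set -euo pipefail

# --password-stdin reads the secret from stdin; the <<< here-string supplies it,
# so the token never shows up as a process argument or in shell history.
docker login --username infiniflow --password-stdin <<< "${DOCKERHUB_TOKEN}"

# Build, tag, and push, mirroring the workflow step.
docker build -t "infiniflow/ragflow:${RELEASE_TAG}" -f Dockerfile .
docker tag "infiniflow/ragflow:${RELEASE_TAG}" infiniflow/ragflow:latest
docker push "infiniflow/ragflow:${RELEASE_TAG}"
docker push infiniflow/ragflow:latest
```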
.github/workflows/tests.yml (vendored): 48 changed lines

@@ -86,6 +86,9 @@ jobs:
             mkdir -p ${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}
             echo "${PR_SHA} ${GITHUB_RUN_ID}" > ${PR_SHA_FP}
           fi
+          ARTIFACTS_DIR=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/${GITHUB_RUN_ID}
+          echo "ARTIFACTS_DIR=${ARTIFACTS_DIR}" >> ${GITHUB_ENV}
+          rm -rf ${ARTIFACTS_DIR} && mkdir -p ${ARTIFACTS_DIR}

       # https://github.com/astral-sh/ruff-action
       - name: Static check with Ruff

@@ -161,7 +164,7 @@ jobs:
           INFINITY_THRIFT_PORT=$((23817 + RUNNER_NUM * 10))
           INFINITY_HTTP_PORT=$((23820 + RUNNER_NUM * 10))
           INFINITY_PSQL_PORT=$((5432 + RUNNER_NUM * 10))
-          MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
+          EXPOSE_MYSQL_PORT=$((5455 + RUNNER_NUM * 10))
           MINIO_PORT=$((9000 + RUNNER_NUM * 10))
           MINIO_CONSOLE_PORT=$((9001 + RUNNER_NUM * 10))
           REDIS_PORT=$((6379 + RUNNER_NUM * 10))

@@ -181,7 +184,7 @@ jobs:
           echo -e "INFINITY_THRIFT_PORT=${INFINITY_THRIFT_PORT}" >> docker/.env
           echo -e "INFINITY_HTTP_PORT=${INFINITY_HTTP_PORT}" >> docker/.env
           echo -e "INFINITY_PSQL_PORT=${INFINITY_PSQL_PORT}" >> docker/.env
-          echo -e "MYSQL_PORT=${MYSQL_PORT}" >> docker/.env
+          echo -e "EXPOSE_MYSQL_PORT=${EXPOSE_MYSQL_PORT}" >> docker/.env
           echo -e "MINIO_PORT=${MINIO_PORT}" >> docker/.env
           echo -e "MINIO_CONSOLE_PORT=${MINIO_CONSOLE_PORT}" >> docker/.env
           echo -e "REDIS_PORT=${REDIS_PORT}" >> docker/.env

@@ -205,29 +208,36 @@ jobs:
       - name: Run sdk tests against Elasticsearch
         run: |
           export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
-          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
+          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
             echo "Waiting for service to be available..."
             sleep 5
           done
-          source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
+          source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log

-      - name: Run frontend api tests against Elasticsearch
+      - name: Run web api tests against Elasticsearch
         run: |
           export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
-          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
+          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
             echo "Waiting for service to be available..."
             sleep 5
           done
-          source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
+          source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/ 2>&1 | tee es_web_api_test.log

       - name: Run http api tests against Elasticsearch
         run: |
           export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
-          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
+          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
             echo "Waiting for service to be available..."
             sleep 5
           done
-          source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
+          source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log

+      - name: Collect ragflow log
+        if: ${{ !cancelled() }}
+        run: |
+          cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-es
+          echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log
+          sudo rm -rf docker/ragflow-logs
+
       - name: Stop ragflow:nightly
         if: always() # always run this step even if previous steps failed

@@ -243,30 +253,36 @@ jobs:
       - name: Run sdk tests against Infinity
         run: |
           export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
-          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
+          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
             echo "Waiting for service to be available..."
             sleep 5
           done
-          source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
+          source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log

-      - name: Run frontend api tests against Infinity
+      - name: Run web api tests against Infinity
         run: |
           export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
-          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
+          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
             echo "Waiting for service to be available..."
             sleep 5
           done
-          source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
+          source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/ 2>&1 | tee infinity_web_api_test.log

       - name: Run http api tests against Infinity
         run: |
           export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
-          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
+          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do
             echo "Waiting for service to be available..."
             sleep 5
           done
-          source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
+          source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log

+      - name: Collect ragflow log
+        if: ${{ !cancelled() }}
+        run: |
+          cp -r docker/ragflow-logs ${ARTIFACTS_DIR}/ragflow-logs-infinity
+          echo "ragflow log" && tail -n 200 docker/ragflow-logs/ragflow_server.log
+          sudo rm -rf docker/ragflow-logs
       - name: Stop ragflow:nightly
         if: always() # always run this step even if previous steps failed
         run: |
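Two changes recur across the test steps above: readiness is now probed against the `/v1/system/ping` endpoint rather than the bare host address, and each pytest invocation runs under `set -o pipefail` before being piped into `tee`. The sketch below illustrates why the pipefail part matters; without it a pipeline reports the exit status of its last command (`tee`, which succeeds), so a failing test run could still pass the CI step:

```bash
#!/usr/bin/env bash
# Sketch of the `set -o pipefail` behavior used before `pytest ... | tee <log>`.

false | tee /tmp/demo.log
echo "without pipefail, pipeline status: $?"   # prints 0 (the failure is masked by tee)

set -o pipefail
false | tee /tmp/demo.log
echo "with pipefail, pipeline status: $?"      # prints 1 (the failure propagates to the step)
```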
.gitignore (vendored): 13 changed lines

@@ -44,6 +44,7 @@ cl100k_base.tiktoken
 chrome*
 huggingface.co/
 nltk_data/
+uv-x86_64*.tar.gz

 # Exclude hash-like temporary files like 9b5ad71b2ce5302211f9c61530b329a4922fc6a4
 *[0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f]*

@@ -51,6 +52,13 @@ nltk_data/
 .venv
 docker/data

+# OceanBase data and conf
+docker/oceanbase/conf
+docker/oceanbase/data
+
+# SeekDB data and conf
+docker/seekdb
+

 #--------------------------------------------------#
 # The following was generated with gitignore.nvim: #

@@ -198,3 +206,8 @@ backup


 .hypothesis
+
+
+# Added by cargo
+
+/target
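If you want to confirm the new ignore rules behave as intended, `git check-ignore -v` reports which pattern matches a given path. The file names below are hypothetical examples, not files that exist in the repository:

```bash
# Each output line names the .gitignore rule that matched the path (if any).
git check-ignore -v \
    docker/oceanbase/conf/observer.conf \
    docker/seekdb/data.db \
    uv-x86_64-unknown-linux-gnu.tar.gz \
    target/debug/build.log
```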
Dockerfile: 26 changed lines

@@ -19,17 +19,17 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co
 # This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
 RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
     cp -r /deps/nltk_data /root/ && \
-    cp /deps/tika-server-standard-3.0.0.jar /deps/tika-server-standard-3.0.0.jar.md5 /ragflow/ && \
+    cp /deps/tika-server-standard-3.2.3.jar /deps/tika-server-standard-3.2.3.jar.md5 /ragflow/ && \
     cp /deps/cl100k_base.tiktoken /ragflow/9b5ad71b2ce5302211f9c61530b329a4922fc6a4

-ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.0.0.jar"
+ENV TIKA_SERVER_JAR="file:///ragflow/tika-server-standard-3.2.3.jar"
 ENV DEBIAN_FRONTEND=noninteractive

 # Setup apt
 # Python package and implicit dependencies:
 # opencv-python: libglib2.0-0 libglx-mesa0 libgl1
 # aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
-# python-pptx: default-jdk tika-server-standard-3.0.0.jar
+# python-pptx: default-jdk tika-server-standard-3.2.3.jar
 # selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
 # Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
 RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \

@@ -53,7 +53,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
     apt install -y ghostscript && \
     apt install -y pandoc && \
     apt install -y texlive && \
-    apt install -y fonts-freefont-ttf fonts-noto-cjk
+    apt install -y fonts-freefont-ttf fonts-noto-cjk && \
+    apt install -y postgresql-client

 # Install uv
 RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \

@@ -64,10 +65,12 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
         echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
         echo 'default = true' >> /etc/uv/uv.toml; \
     fi; \
-    tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \
-    && cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \
-    && rm -rf uv-x86_64-unknown-linux-gnu \
-    && uv python install 3.11
+    arch="$(uname -m)"; \
+    if [ "$arch" = "x86_64" ]; then uv_arch="x86_64"; else uv_arch="aarch64"; fi; \
+    tar xzf "/deps/uv-${uv_arch}-unknown-linux-gnu.tar.gz" \
+    && cp "uv-${uv_arch}-unknown-linux-gnu/"* /usr/local/bin/ \
+    && rm -rf "uv-${uv_arch}-unknown-linux-gnu" \
+    && uv python install 3.12

 ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
 ENV PATH=/root/.local/bin:$PATH

@@ -152,11 +155,14 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
     else \
         sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
     fi; \
-    uv sync --python 3.12 --frozen
+    uv sync --python 3.12 --frozen && \
+    # Ensure pip is available in the venv for runtime package installation (fixes #12651)
+    .venv/bin/python3 -m ensurepip --upgrade

 COPY web web
 COPY docs docs
 RUN --mount=type=cache,id=ragflow_npm,target=/root/.npm,sharing=locked \
+    export NODE_OPTIONS="--max-old-space-size=4096" && \
     cd web && npm install && npm run build

 COPY .git /ragflow/.git

@@ -187,11 +193,11 @@ COPY deepdoc deepdoc
 COPY rag rag
 COPY agent agent
 COPY graphrag graphrag
-COPY agentic_reasoning agentic_reasoning
 COPY pyproject.toml uv.lock ./
 COPY mcp mcp
 COPY plugin plugin
 COPY common common
+COPY memory memory

 COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
 COPY docker/entrypoint.sh ./
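The uv installation step is now architecture-aware and the pinned interpreter moves from Python 3.11 to 3.12. A standalone sketch of the same selection logic, run outside a Dockerfile (the `/deps` paths mirror the build context above and are assumptions here):

```bash
#!/usr/bin/env bash
# Sketch: pick the uv release archive that matches the build machine's architecture.
set -euo pipefail

arch="$(uname -m)"                     # "x86_64" on Intel/AMD hosts, "aarch64" on ARM64
if [ "$arch" = "x86_64" ]; then
    uv_arch="x86_64"
else
    uv_arch="aarch64"                  # anything non-x86_64 falls back to the ARM64 tarball
fi

tar xzf "/deps/uv-${uv_arch}-unknown-linux-gnu.tar.gz"
cp "uv-${uv_arch}-unknown-linux-gnu/"* /usr/local/bin/
rm -rf "uv-${uv_arch}-unknown-linux-gnu"
uv python install 3.12                 # the image now standardizes on Python 3.12
```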
Deps image (FROM scratch):

@@ -3,7 +3,7 @@
 FROM scratch

 # Copy resources downloaded via download_deps.py
-COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz /
+COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.2.3.jar tika-server-standard-3.2.3.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz uv-aarch64-unknown-linux-gnu.tar.gz /

 COPY nltk_data /nltk_data

README.md: 32 changed lines

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">

@@ -37,7 +37,7 @@

 <h4 align="center">
 <a href="https://ragflow.io/docs/dev/">Document</a> |
-<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
+<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
 <a href="https://twitter.com/infiniflowai">Twitter</a> |
 <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
 <a href="https://demo.ragflow.io">Demo</a>

@@ -72,7 +72,7 @@

 ## 💡 What is RAGFlow?

-[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation (RAG) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged context engine and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.
+[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision.

 ## 🎮 Demo

@@ -85,6 +85,7 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).

 ## 🔥 Latest Updates

+- 2025-12-26 Supports 'Memory' for AI agent.
 - 2025-11-19 Supports Gemini 3 Pro.
 - 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, Google Drive.
 - 2025-10-23 Supports MinerU & Docling as document parsing methods.

@@ -187,12 +188,12 @@ releases! 🌟
 > All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
 > If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.

-> The command below downloads the `v0.22.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
+> The command below downloads the `v0.23.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.23.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.

 ```bash
 $ cd ragflow/docker

-# git checkout v0.22.1
+# git checkout v0.23.1
 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
 # This step ensures the **entrypoint.sh** file in the code matches the Docker image version.

@@ -206,10 +207,10 @@ releases! 🌟

 > Note: Prior to `v0.22.0`, we provided both images with embedding models and slim images without embedding models. Details as follows:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

 > Starting with `v0.22.0`, we ship only the slim edition and no longer append the **-slim** suffix to the image tag.

@@ -232,7 +233,7 @@ releases! 🌟
 * Running on all addresses (0.0.0.0)
 ```

-> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network anormal`
+> If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network abnormal`
 > error because, at that moment, your RAGFlow may not be fully initialized.
 >
 5. In your web browser, enter the IP address of your server and log in to RAGFlow.

@@ -302,6 +303,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```

+Or if you are behind a proxy, you can pass proxy arguments:
+
+```bash
+docker build --platform linux/amd64 \
+  --build-arg http_proxy=http://YOUR_PROXY:PORT \
+  --build-arg https_proxy=http://YOUR_PROXY:PORT \
+  -f Dockerfile -t infiniflow/ragflow:nightly .
+```
+
 ## 🔨 Launch service from source for development

 1. Install `uv` and `pre-commit`, or skip this step if they are already installed:

@@ -386,7 +396,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .

 ## 📜 Roadmap

-See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
+See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)

 ## 🏄 Community

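The quick-start text above pivots on the `RAGFLOW_IMAGE` variable in **docker/.env**. A minimal sketch of checking and switching editions before starting the stack, assuming a checked-out ragflow repository (the grep pattern is illustrative):

```bash
cd ragflow/docker

# Show which image edition docker compose will pull; RAGFLOW_IMAGE is defined in docker/.env.
grep 'RAGFLOW_IMAGE' .env

# After editing RAGFLOW_IMAGE (e.g. to infiniflow/ragflow:v0.23.1), start the server.
docker compose up -d
```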
README_id.md: 32 changed lines

@@ -22,7 +22,7 @@
 <img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">

@@ -37,7 +37,7 @@

 <h4 align="center">
 <a href="https://ragflow.io/docs/dev/">Dokumentasi</a> |
-<a href="https://github.com/infiniflow/ragflow/issues/4214">Peta Jalan</a> |
+<a href="https://github.com/infiniflow/ragflow/issues/12241">Peta Jalan</a> |
 <a href="https://twitter.com/infiniflowai">Twitter</a> |
 <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
 <a href="https://demo.ragflow.io">Demo</a>

@@ -72,7 +72,7 @@

 ## 💡 Apa Itu RAGFlow?

-[RAGFlow](https://ragflow.io/) adalah mesin RAG (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa.
+[RAGFlow](https://ragflow.io/) adalah mesin [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa.

 ## 🎮 Demo

@@ -85,6 +85,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).

 ## 🔥 Pembaruan Terbaru

+- 2025-12-26 Mendukung 'Memori' untuk agen AI.
 - 2025-11-19 Mendukung Gemini 3 Pro.
 - 2025-11-12 Mendukung sinkronisasi data dari Confluence, S3, Notion, Discord, Google Drive.
 - 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen.

@@ -187,12 +188,12 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 > Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
 > Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).

-> Perintah di bawah ini mengunduh edisi v0.22.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
+> Perintah di bawah ini mengunduh edisi v0.23.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.23.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.

 ```bash
 $ cd ragflow/docker

-# git checkout v0.22.1
+# git checkout v0.23.1
 # Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases)
 # This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.

@@ -206,10 +207,10 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).

 > Catatan: Sebelum `v0.22.0`, kami menyediakan image dengan model embedding dan image slim tanpa model embedding. Detailnya sebagai berikut:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

 > Mulai dari `v0.22.0`, kami hanya menyediakan edisi slim dan tidak lagi menambahkan akhiran **-slim** pada tag image.

@@ -232,7 +233,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 * Running on all addresses (0.0.0.0)
 ```

-> Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network anormal`
+> Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network abnormal`
 > karena RAGFlow mungkin belum sepenuhnya siap.
 >
 2. Buka browser web Anda, masukkan alamat IP server Anda, dan login ke RAGFlow.

@@ -276,6 +277,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```

+Jika berada di belakang proxy, Anda dapat melewatkan argumen proxy:
+
+```bash
+docker build --platform linux/amd64 \
+  --build-arg http_proxy=http://YOUR_PROXY:PORT \
+  --build-arg https_proxy=http://YOUR_PROXY:PORT \
+  -f Dockerfile -t infiniflow/ragflow:nightly .
+```
+
 ## 🔨 Menjalankan Aplikasi dari untuk Pengembangan

 1. Instal `uv` dan `pre-commit`, atau lewati langkah ini jika sudah terinstal:

@@ -358,7 +368,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .

 ## 📜 Roadmap

-Lihat [Roadmap RAGFlow 2025](https://github.com/infiniflow/ragflow/issues/4214)
+Lihat [Roadmap RAGFlow 2026](https://github.com/infiniflow/ragflow/issues/12241)

 ## 🏄 Komunitas

README_ja.md: 32 changed lines

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">

@@ -37,7 +37,7 @@

 <h4 align="center">
 <a href="https://ragflow.io/docs/dev/">Document</a> |
-<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
+<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
 <a href="https://twitter.com/infiniflowai">Twitter</a> |
 <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
 <a href="https://demo.ragflow.io">Demo</a>

@@ -53,7 +53,7 @@

 ## 💡 RAGFlow とは?

-[RAGFlow](https://ragflow.io/) は、先進的なRAG(Retrieval-Augmented Generation)技術と Agent 機能を融合し、大規模言語モデル(LLM)に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型コンテキストエンジンと事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。
+[RAGFlow](https://ragflow.io/) は、先進的な[RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)技術と Agent 機能を融合し、大規模言語モデル(LLM)に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型[コンテキストエンジン](https://ragflow.io/basics/what-is-agent-context-engine)と事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。

 ## 🎮 Demo

@@ -66,7 +66,8 @@

 ## 🔥 最新情報

-- 2025-11-19 Gemini 3 Proをサポートしています
+- 2025-12-26 AIエージェントの「メモリ」機能をサポート。
+- 2025-11-19 Gemini 3 Proをサポートしています。
 - 2025-11-12 Confluence、S3、Notion、Discord、Google Drive からのデータ同期をサポートします。
 - 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。
 - 2025-10-15 オーケストレーションされたデータパイプラインのサポート。

@@ -167,12 +168,12 @@
 > 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
 > ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。

-> 以下のコマンドは、RAGFlow Docker イメージの v0.22.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
+> 以下のコマンドは、RAGFlow Docker イメージの v0.23.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.23.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。

 ```bash
 $ cd ragflow/docker

-# git checkout v0.22.1
+# git checkout v0.23.1
 # 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases)
 # この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。

@@ -186,10 +187,10 @@

 > 注意:`v0.22.0` より前のバージョンでは、embedding モデルを含むイメージと、embedding モデルを含まない slim イメージの両方を提供していました。詳細は以下の通りです:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

 > `v0.22.0` 以降、当プロジェクトでは slim エディションのみを提供し、イメージタグに **-slim** サフィックスを付けなくなりました。

@@ -276,6 +277,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```

+プロキシ環境下にいる場合は、プロキシ引数を指定できます:
+
+```bash
+docker build --platform linux/amd64 \
+  --build-arg http_proxy=http://YOUR_PROXY:PORT \
+  --build-arg https_proxy=http://YOUR_PROXY:PORT \
+  -f Dockerfile -t infiniflow/ragflow:nightly .
+```
+
 ## 🔨 ソースコードからサービスを起動する方法

 1. `uv` と `pre-commit` をインストールする。すでにインストールされている場合は、このステップをスキップしてください:

@@ -358,7 +368,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .

 ## 📜 ロードマップ

-[RAGFlow ロードマップ 2025](https://github.com/infiniflow/ragflow/issues/4214) を参照
+[RAGFlow ロードマップ 2026](https://github.com/infiniflow/ragflow/issues/12241) を参照

 ## 🏄 コミュニティ

README_ko.md: 32 changed lines

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">

@@ -37,7 +37,7 @@

 <h4 align="center">
 <a href="https://ragflow.io/docs/dev/">Document</a> |
-<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
+<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
 <a href="https://twitter.com/infiniflowai">Twitter</a> |
 <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
 <a href="https://demo.ragflow.io">Demo</a>

@@ -54,7 +54,7 @@

 ## 💡 RAGFlow란?

-[RAGFlow](https://ragflow.io/) 는 최첨단 RAG(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 컨텍스트 엔진과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다.
+[RAGFlow](https://ragflow.io/) 는 최첨단 [RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 [컨텍스트 엔진](https://ragflow.io/basics/what-is-agent-context-engine)과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다.

 ## 🎮 데모

@@ -67,6 +67,7 @@

 ## 🔥 업데이트

+- 2025-12-26 AI 에이전트의 '메모리' 기능 지원.
 - 2025-11-19 Gemini 3 Pro를 지원합니다.
 - 2025-11-12 Confluence, S3, Notion, Discord, Google Drive에서 데이터 동기화를 지원합니다.
 - 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다.

@@ -169,12 +170,12 @@
 > 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
 > ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).

-> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
+> 아래 명령어는 RAGFlow Docker 이미지의 v0.23.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.23.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.

 ```bash
 $ cd ragflow/docker

-# git checkout v0.22.1
+# git checkout v0.23.1
 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
 # 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.

@@ -188,10 +189,10 @@

 > 참고: `v0.22.0` 이전 버전에서는 embedding 모델이 포함된 이미지와 embedding 모델이 포함되지 않은 slim 이미지를 모두 제공했습니다. 자세한 내용은 다음과 같습니다:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

 > `v0.22.0`부터는 slim 에디션만 배포하며 이미지 태그에 **-slim** 접미사를 더 이상 붙이지 않습니다.

@@ -213,7 +214,7 @@
 * Running on all addresses (0.0.0.0)
 ```

-> 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network anormal` 오류가 발생할 수 있습니다.
+> 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network abnormal` 오류가 발생할 수 있습니다.

 2. 웹 브라우저에 서버의 IP 주소를 입력하고 RAGFlow에 로그인하세요.
 > 기본 설정을 사용할 경우, `http://IP_OF_YOUR_MACHINE`만 입력하면 됩니다 (포트 번호는 제외). 기본 HTTP 서비스 포트 `80`은 기본 구성으로 사용할 때 생략할 수 있습니다.

@@ -270,6 +271,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```

+프록시 환경인 경우, 프록시 인수를 전달할 수 있습니다:
+
+```bash
+docker build --platform linux/amd64 \
+  --build-arg http_proxy=http://YOUR_PROXY:PORT \
+  --build-arg https_proxy=http://YOUR_PROXY:PORT \
+  -f Dockerfile -t infiniflow/ragflow:nightly .
+```
+
 ## 🔨 소스 코드로 서비스를 시작합니다.

 1. `uv` 와 `pre-commit` 을 설치하거나, 이미 설치된 경우 이 단계를 건너뜁니다:

@@ -362,7 +372,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .

 ## 📜 로드맵

-[RAGFlow 로드맵 2025](https://github.com/infiniflow/ragflow/issues/4214)을 확인하세요.
+[RAGFlow 로드맵 2026](https://github.com/infiniflow/ragflow/issues/12241)을 확인하세요.
|
[RAGFlow 로드맵 2026](https://github.com/infiniflow/ragflow/issues/12241)을 확인하세요.
|
||||||
|
|
||||||
## 🏄 커뮤니티
|
## 🏄 커뮤니티
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@
 <img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
@@ -37,7 +37,7 @@

 <h4 align="center">
 <a href="https://ragflow.io/docs/dev/">Documentação</a> |
-<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
+<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
 <a href="https://twitter.com/infiniflowai">Twitter</a> |
 <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
 <a href="https://demo.ragflow.io">Demo</a>
@@ -73,7 +73,7 @@

 ## 💡 O que é o RAGFlow?

-[RAGFlow](https://ragflow.io/) é um mecanismo de RAG (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por um motor de contexto convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão.
+[RAGFlow](https://ragflow.io/) é um mecanismo de [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por [um motor de contexto](https://ragflow.io/basics/what-is-agent-context-engine) convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão.

 ## 🎮 Demo

@@ -86,6 +86,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).

 ## 🔥 Últimas Atualizações

+- 26-12-2025 Suporte à função 'Memória' para agentes de IA.
 - 19-11-2025 Suporta Gemini 3 Pro.
 - 12-11-2025 Suporta a sincronização de dados do Confluence, S3, Notion, Discord e Google Drive.
 - 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos.
@@ -187,12 +188,12 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 > Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
 > Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.

-> O comando abaixo baixa a edição`v0.22.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
+> O comando abaixo baixa a edição`v0.23.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.23.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.

 ```bash
 $ cd ragflow/docker

-# git checkout v0.22.1
+# git checkout v0.23.1
 # Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases)
 # Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.

@@ -206,10 +207,10 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).

 > Nota: Antes da `v0.22.0`, fornecíamos imagens com modelos de embedding e imagens slim sem modelos de embedding. Detalhes a seguir:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

 > A partir da `v0.22.0`, distribuímos apenas a edição slim e não adicionamos mais o sufixo **-slim** às tags das imagens.

@@ -231,7 +232,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 * Rodando em todos os endereços (0.0.0.0)
 ```

-> Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network anormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado.
+> Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network abnormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado.
 >
 5. No seu navegador, insira o endereço IP do seu servidor e faça login no RAGFlow.

@@ -293,6 +294,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```

+Se você estiver atrás de um proxy, pode passar argumentos de proxy:
+
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
+
 ## 🔨 Lançar o serviço a partir do código-fonte para desenvolvimento

 1. Instale o `uv` e o `pre-commit`, ou pule esta etapa se eles já estiverem instalados:
@@ -375,7 +385,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly

 ## 📜 Roadmap

-Veja o [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214)
+Veja o [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241)

 ## 🏄 Comunidade

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -37,7 +37,7 @@

 <h4 align="center">
 <a href="https://ragflow.io/docs/dev/">Document</a> |
-<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
+<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
 <a href="https://twitter.com/infiniflowai">Twitter</a> |
 <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
 <a href="https://demo.ragflow.io">Demo</a>
@@ -72,7 +72,7 @@

 ## 💡 RAGFlow 是什麼?

-[RAGFlow](https://ragflow.io/) 是一款領先的開源 RAG(Retrieval-Augmented Generation)引擎,通過融合前沿的 RAG 技術與 Agent 能力,為大型語言模型提供卓越的上下文層。它提供可適配任意規模企業的端到端 RAG 工作流,憑藉融合式上下文引擎與預置的 Agent 模板,助力開發者以極致效率與精度將複雜數據轉化為高可信、生產級的人工智能系統。
+[RAGFlow](https://ragflow.io/) 是一款領先的開源 [RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)引擎,通過融合前沿的 RAG 技術與 Agent 能力,為大型語言模型提供卓越的上下文層。它提供可適配任意規模企業的端到端 RAG 工作流,憑藉融合式[上下文引擎](https://ragflow.io/basics/what-is-agent-context-engine)與預置的 Agent 模板,助力開發者以極致效率與精度將複雜數據轉化為高可信、生產級的人工智能系統。

 ## 🎮 Demo 試用

@@ -85,15 +85,16 @@

 ## 🔥 近期更新

-- 2025-11-19 支援 Gemini 3 Pro.
+- 2025-12-26 支援AI代理的「記憶」功能。
+- 2025-11-19 支援 Gemini 3 Pro。
 - 2025-11-12 支援從 Confluence、S3、Notion、Discord、Google Drive 進行資料同步。
 - 2025-10-23 支援 MinerU 和 Docling 作為文件解析方法。
 - 2025-10-15 支援可編排的資料管道。
 - 2025-08-08 支援 OpenAI 最新的 GPT-5 系列模型。
-- 2025-08-01 支援 agentic workflow 和 MCP
+- 2025-08-01 支援 agentic workflow 和 MCP。
 - 2025-05-23 為 Agent 新增 Python/JS 程式碼執行器元件。
 - 2025-05-05 支援跨語言查詢。
-- 2025-03-19 PDF和DOCX中的圖支持用多模態大模型去解析得到描述.
+- 2025-03-19 PDF和DOCX中的圖支持用多模態大模型去解析得到描述。
 - 2024-12-18 升級了 DeepDoc 的文檔佈局分析模型。
 - 2024-08-22 支援用 RAG 技術實現從自然語言到 SQL 語句的轉換。

@@ -124,7 +125,7 @@

 ### 🍔 **相容各類異質資料來源**

-- 支援豐富的文件類型,包括 Word 文件、PPT、excel 表格、txt 檔案、圖片、PDF、影印件、影印件、結構化資料、網頁等。
+- 支援豐富的文件類型,包括 Word 文件、PPT、excel 表格、txt 檔案、圖片、PDF、影印件、複印件、結構化資料、網頁等。

 ### 🛀 **全程無憂、自動化的 RAG 工作流程**

@@ -186,12 +187,12 @@
 > 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。
 > 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。

-> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
+> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.23.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.23.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。

 ```bash
 $ cd ragflow/docker

-# git checkout v0.22.1
+# git checkout v0.23.1
 # 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases)
 # 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。

@@ -205,10 +206,10 @@

 > 注意:在 `v0.22.0` 之前的版本,我們會同時提供包含 embedding 模型的映像和不含 embedding 模型的 slim 映像。具體如下:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

 > 從 `v0.22.0` 開始,我們只發佈 slim 版本,並且不再在映像標籤後附加 **-slim** 後綴。

@@ -236,7 +237,7 @@
 * Running on all addresses (0.0.0.0)
 ```

-> 如果您跳過這一步驟系統確認步驟就登入 RAGFlow,你的瀏覽器有可能會提示 `network anormal` 或 `網路異常`,因為 RAGFlow 可能並未完全啟動成功。
+> 如果您跳過這一步驟系統確認步驟就登入 RAGFlow,你的瀏覽器有可能會提示 `network abnormal` 或 `網路異常`,因為 RAGFlow 可能並未完全啟動成功。
 >
 5. 在你的瀏覽器中輸入你的伺服器對應的 IP 位址並登入 RAGFlow。

@@ -302,6 +303,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```

+若您位於代理環境,可傳遞代理參數:
+
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
+
 ## 🔨 以原始碼啟動服務

 1. 安裝 `uv` 和 `pre-commit`。如已安裝,可跳過此步驟:
@@ -389,7 +399,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly

 ## 📜 路線圖

-詳見 [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214) 。
+詳見 [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241) 。

 ## 🏄 開源社群

README_zh.md (36 changes)
@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.23.1">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -37,7 +37,7 @@

 <h4 align="center">
 <a href="https://ragflow.io/docs/dev/">Document</a> |
-<a href="https://github.com/infiniflow/ragflow/issues/4214">Roadmap</a> |
+<a href="https://github.com/infiniflow/ragflow/issues/12241">Roadmap</a> |
 <a href="https://twitter.com/infiniflowai">Twitter</a> |
 <a href="https://discord.gg/NjYzJD3GM3">Discord</a> |
 <a href="https://demo.ragflow.io">Demo</a>
@@ -72,7 +72,7 @@

 ## 💡 RAGFlow 是什么?

-[RAGFlow](https://ragflow.io/) 是一款领先的开源检索增强生成(RAG)引擎,通过融合前沿的 RAG 技术与 Agent 能力,为大型语言模型提供卓越的上下文层。它提供可适配任意规模企业的端到端 RAG 工作流,凭借融合式上下文引擎与预置的 Agent 模板,助力开发者以极致效率与精度将复杂数据转化为高可信、生产级的人工智能系统。
+[RAGFlow](https://ragflow.io/) 是一款领先的开源检索增强生成([RAG](https://ragflow.io/basics/what-is-rag))引擎,通过融合前沿的 RAG 技术与 Agent 能力,为大型语言模型提供卓越的上下文层。它提供可适配任意规模企业的端到端 RAG 工作流,凭借融合式[上下文引擎](https://ragflow.io/basics/what-is-agent-context-engine)与预置的 Agent 模板,助力开发者以极致效率与精度将复杂数据转化为高可信、生产级的人工智能系统。

 ## 🎮 Demo 试用

@@ -85,7 +85,8 @@

 ## 🔥 近期更新

-- 2025-11-19 支持 Gemini 3 Pro.
+- 2025-12-26 支持AI代理的“记忆”功能。
+- 2025-11-19 支持 Gemini 3 Pro。
 - 2025-11-12 支持从 Confluence、S3、Notion、Discord、Google Drive 进行数据同步。
 - 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。
 - 2025-10-15 支持可编排的数据管道。
@@ -93,7 +94,7 @@
 - 2025-08-01 支持 agentic workflow 和 MCP。
 - 2025-05-23 Agent 新增 Python/JS 代码执行器组件。
 - 2025-05-05 支持跨语言查询。
-- 2025-03-19 PDF 和 DOCX 中的图支持用多模态大模型去解析得到描述.
+- 2025-03-19 PDF 和 DOCX 中的图支持用多模态大模型去解析得到描述。
 - 2024-12-18 升级了 DeepDoc 的文档布局分析模型。
 - 2024-08-22 支持用 RAG 技术实现从自然语言到 SQL 语句的转换。

@@ -187,12 +188,12 @@
 > 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。
 > 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。

-> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
+> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.23.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.23.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。

 ```bash
 $ cd ragflow/docker

-# git checkout v0.22.1
+# git checkout v0.23.1
 # 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases)
 # 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。

@@ -206,10 +207,10 @@

 > 注意:在 `v0.22.0` 之前的版本,我们会同时提供包含 embedding 模型的镜像和不含 embedding 模型的 slim 镜像。具体如下:

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
-| ----------------- | --------------- | --------------------- | ------------------------ |
+|-------------------|-----------------|-----------------------|----------------|
 | v0.21.1 | ≈9 | ✔️ | Stable release |
 | v0.21.1-slim | ≈2 | ❌ | Stable release |

 > 从 `v0.22.0` 开始,我们只发布 slim 版本,并且不再在镜像标签后附加 **-slim** 后缀。

@@ -237,7 +238,7 @@
 * Running on all addresses (0.0.0.0)
 ```

-> 如果您在没有看到上面的提示信息出来之前,就尝试登录 RAGFlow,你的浏览器有可能会提示 `network anormal` 或 `网络异常`。
+> 如果您在没有看到上面的提示信息出来之前,就尝试登录 RAGFlow,你的浏览器有可能会提示 `network abnormal` 或 `网络异常`。

 5. 在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
 > 上面这个例子中,您只需输入 http://IP_OF_YOUR_MACHINE 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。
@@ -301,6 +302,15 @@ cd ragflow/
 docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly .
 ```

+如果您处在代理环境下,可以传递代理参数:
+
+```bash
+docker build --platform linux/amd64 \
+--build-arg http_proxy=http://YOUR_PROXY:PORT \
+--build-arg https_proxy=http://YOUR_PROXY:PORT \
+-f Dockerfile -t infiniflow/ragflow:nightly .
+```
+
 ## 🔨 以源代码启动服务

 1. 安装 `uv` 和 `pre-commit`。如已经安装,可跳过本步骤:
@@ -392,7 +402,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly

 ## 📜 路线图

-详见 [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214) 。
+详见 [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241) 。

 ## 🏄 开源社区

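The same version-bump note appears in each localized README above: pin the `RAGFLOW_IMAGE` variable in **docker/.env** to the tag you want before starting the server with `docker compose`. A minimal sketch of that step, assuming the stack is driven from `ragflow/docker` by a default `docker-compose.yml` (the compose file name and the `-d` flag are assumptions, not shown in the diff):

```bash
cd ragflow/docker

# Optional: check out the matching tag so entrypoint.sh matches the image version
# git checkout v0.23.1

# Pin the image tag the compose stack will pull (docker/.env)
sed -i 's|^RAGFLOW_IMAGE=.*|RAGFLOW_IMAGE=infiniflow/ragflow:v0.23.1|' .env

# Start the server with the pinned image
docker compose -f docker-compose.yml up -d
```
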
@@ -6,8 +6,8 @@ Use this section to tell people about which versions of your project are
 currently being supported with security updates.

 | Version | Supported |
-| ------- | ------------------ |
+|---------|--------------------|
 | <=0.7.0 | :white_check_mark: |

 ## Reporting a Vulnerability

@@ -21,7 +21,7 @@ cp pyproject.toml release/$PROJECT_NAME/pyproject.toml
 cp README.md release/$PROJECT_NAME/README.md

 mkdir release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR -p
-cp admin_client.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/admin_client.py
+cp ragflow_cli.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/ragflow_cli.py

 if [ -d "release/$PROJECT_NAME/$SOURCE_DIR" ]; then
 echo "✅ source dir: release/$PROJECT_NAME/$SOURCE_DIR"
@@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
 1. Ensure the Admin Service is running.
 2. Install ragflow-cli.
 ```bash
-pip install ragflow-cli==0.22.1
+pip install ragflow-cli==0.23.1
 ```
 3. Launch the CLI client:
 ```bash
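For reference, the renamed package is driven the same way the deleted `admin_client.py` below was: its argparse options default to host `localhost` and port `9381`, with `-w` for the superuser password. A minimal usage sketch, assuming the wheel exposes a `ragflow-cli` console command (the exact entry-point name is an assumption):

```bash
# Install the renamed CLI client
pip install ragflow-cli==0.23.1

# Connect to a running Admin Service; -h/-p/-w mirror the argparse
# options defined in the client (defaults: localhost, 9381)
ragflow-cli -h 127.0.0.1 -p 9381
```
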
@@ -1,978 +0,0 @@
#
|
|
||||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import base64
|
|
||||||
from cmd import Cmd
|
|
||||||
|
|
||||||
from Cryptodome.PublicKey import RSA
|
|
||||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
|
||||||
from typing import Dict, List, Any
|
|
||||||
from lark import Lark, Transformer, Tree
|
|
||||||
import requests
|
|
||||||
import getpass
|
|
||||||
|
|
||||||
GRAMMAR = r"""
|
|
||||||
start: command
|
|
||||||
|
|
||||||
command: sql_command | meta_command
|
|
||||||
|
|
||||||
sql_command: list_services
|
|
||||||
| show_service
|
|
||||||
| startup_service
|
|
||||||
| shutdown_service
|
|
||||||
| restart_service
|
|
||||||
| list_users
|
|
||||||
| show_user
|
|
||||||
| drop_user
|
|
||||||
| alter_user
|
|
||||||
| create_user
|
|
||||||
| activate_user
|
|
||||||
| list_datasets
|
|
||||||
| list_agents
|
|
||||||
| create_role
|
|
||||||
| drop_role
|
|
||||||
| alter_role
|
|
||||||
| list_roles
|
|
||||||
| show_role
|
|
||||||
| grant_permission
|
|
||||||
| revoke_permission
|
|
||||||
| alter_user_role
|
|
||||||
| show_user_permission
|
|
||||||
| show_version
|
|
||||||
|
|
||||||
// meta command definition
|
|
||||||
meta_command: "\\" meta_command_name [meta_args]
|
|
||||||
|
|
||||||
meta_command_name: /[a-zA-Z?]+/
|
|
||||||
meta_args: (meta_arg)+
|
|
||||||
|
|
||||||
meta_arg: /[^\\s"']+/ | quoted_string
|
|
||||||
|
|
||||||
// command definition
|
|
||||||
|
|
||||||
LIST: "LIST"i
|
|
||||||
SERVICES: "SERVICES"i
|
|
||||||
SHOW: "SHOW"i
|
|
||||||
CREATE: "CREATE"i
|
|
||||||
SERVICE: "SERVICE"i
|
|
||||||
SHUTDOWN: "SHUTDOWN"i
|
|
||||||
STARTUP: "STARTUP"i
|
|
||||||
RESTART: "RESTART"i
|
|
||||||
USERS: "USERS"i
|
|
||||||
DROP: "DROP"i
|
|
||||||
USER: "USER"i
|
|
||||||
ALTER: "ALTER"i
|
|
||||||
ACTIVE: "ACTIVE"i
|
|
||||||
PASSWORD: "PASSWORD"i
|
|
||||||
DATASETS: "DATASETS"i
|
|
||||||
OF: "OF"i
|
|
||||||
AGENTS: "AGENTS"i
|
|
||||||
ROLE: "ROLE"i
|
|
||||||
ROLES: "ROLES"i
|
|
||||||
DESCRIPTION: "DESCRIPTION"i
|
|
||||||
GRANT: "GRANT"i
|
|
||||||
REVOKE: "REVOKE"i
|
|
||||||
ALL: "ALL"i
|
|
||||||
PERMISSION: "PERMISSION"i
|
|
||||||
TO: "TO"i
|
|
||||||
FROM: "FROM"i
|
|
||||||
FOR: "FOR"i
|
|
||||||
RESOURCES: "RESOURCES"i
|
|
||||||
ON: "ON"i
|
|
||||||
SET: "SET"i
|
|
||||||
VERSION: "VERSION"i
|
|
||||||
|
|
||||||
list_services: LIST SERVICES ";"
|
|
||||||
show_service: SHOW SERVICE NUMBER ";"
|
|
||||||
startup_service: STARTUP SERVICE NUMBER ";"
|
|
||||||
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
|
|
||||||
restart_service: RESTART SERVICE NUMBER ";"
|
|
||||||
|
|
||||||
list_users: LIST USERS ";"
|
|
||||||
drop_user: DROP USER quoted_string ";"
|
|
||||||
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
|
||||||
show_user: SHOW USER quoted_string ";"
|
|
||||||
create_user: CREATE USER quoted_string quoted_string ";"
|
|
||||||
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
|
||||||
|
|
||||||
list_datasets: LIST DATASETS OF quoted_string ";"
|
|
||||||
list_agents: LIST AGENTS OF quoted_string ";"
|
|
||||||
|
|
||||||
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
|
|
||||||
drop_role: DROP ROLE identifier ";"
|
|
||||||
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
|
|
||||||
list_roles: LIST ROLES ";"
|
|
||||||
show_role: SHOW ROLE identifier ";"
|
|
||||||
|
|
||||||
grant_permission: GRANT action_list ON identifier TO ROLE identifier ";"
|
|
||||||
revoke_permission: REVOKE action_list ON identifier FROM ROLE identifier ";"
|
|
||||||
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
|
|
||||||
show_user_permission: SHOW USER PERMISSION quoted_string ";"
|
|
||||||
|
|
||||||
show_version: SHOW VERSION ";"
|
|
||||||
|
|
||||||
action_list: identifier ("," identifier)*
|
|
||||||
|
|
||||||
identifier: WORD
|
|
||||||
quoted_string: QUOTED_STRING
|
|
||||||
status: WORD
|
|
||||||
|
|
||||||
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
|
||||||
WORD: /[a-zA-Z0-9_\-\.]+/
|
|
||||||
NUMBER: /[0-9]+/
|
|
||||||
|
|
||||||
%import common.WS
|
|
||||||
%ignore WS
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class AdminTransformer(Transformer):
|
|
||||||
|
|
||||||
def start(self, items):
|
|
||||||
return items[0]
|
|
||||||
|
|
||||||
def command(self, items):
|
|
||||||
return items[0]
|
|
||||||
|
|
||||||
def list_services(self, items):
|
|
||||||
result = {'type': 'list_services'}
|
|
||||||
return result
|
|
||||||
|
|
||||||
def show_service(self, items):
|
|
||||||
service_id = int(items[2])
|
|
||||||
return {"type": "show_service", "number": service_id}
|
|
||||||
|
|
||||||
def startup_service(self, items):
|
|
||||||
service_id = int(items[2])
|
|
||||||
return {"type": "startup_service", "number": service_id}
|
|
||||||
|
|
||||||
def shutdown_service(self, items):
|
|
||||||
service_id = int(items[2])
|
|
||||||
return {"type": "shutdown_service", "number": service_id}
|
|
||||||
|
|
||||||
def restart_service(self, items):
|
|
||||||
service_id = int(items[2])
|
|
||||||
return {"type": "restart_service", "number": service_id}
|
|
||||||
|
|
||||||
def list_users(self, items):
|
|
||||||
return {"type": "list_users"}
|
|
||||||
|
|
||||||
def show_user(self, items):
|
|
||||||
user_name = items[2]
|
|
||||||
return {"type": "show_user", "user_name": user_name}
|
|
||||||
|
|
||||||
def drop_user(self, items):
|
|
||||||
user_name = items[2]
|
|
||||||
return {"type": "drop_user", "user_name": user_name}
|
|
||||||
|
|
||||||
def alter_user(self, items):
|
|
||||||
user_name = items[3]
|
|
||||||
new_password = items[4]
|
|
||||||
return {"type": "alter_user", "user_name": user_name, "password": new_password}
|
|
||||||
|
|
||||||
def create_user(self, items):
|
|
||||||
user_name = items[2]
|
|
||||||
password = items[3]
|
|
||||||
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
|
|
||||||
|
|
||||||
def activate_user(self, items):
|
|
||||||
user_name = items[3]
|
|
||||||
activate_status = items[4]
|
|
||||||
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
|
|
||||||
|
|
||||||
def list_datasets(self, items):
|
|
||||||
user_name = items[3]
|
|
||||||
return {"type": "list_datasets", "user_name": user_name}
|
|
||||||
|
|
||||||
def list_agents(self, items):
|
|
||||||
user_name = items[3]
|
|
||||||
return {"type": "list_agents", "user_name": user_name}
|
|
||||||
|
|
||||||
def create_role(self, items):
|
|
||||||
role_name = items[2]
|
|
||||||
if len(items) > 4:
|
|
||||||
description = items[4]
|
|
||||||
return {"type": "create_role", "role_name": role_name, "description": description}
|
|
||||||
else:
|
|
||||||
return {"type": "create_role", "role_name": role_name}
|
|
||||||
|
|
||||||
def drop_role(self, items):
|
|
||||||
role_name = items[2]
|
|
||||||
return {"type": "drop_role", "role_name": role_name}
|
|
||||||
|
|
||||||
def alter_role(self, items):
|
|
||||||
role_name = items[2]
|
|
||||||
description = items[5]
|
|
||||||
return {"type": "alter_role", "role_name": role_name, "description": description}
|
|
||||||
|
|
||||||
def list_roles(self, items):
|
|
||||||
return {"type": "list_roles"}
|
|
||||||
|
|
||||||
def show_role(self, items):
|
|
||||||
role_name = items[2]
|
|
||||||
return {"type": "show_role", "role_name": role_name}
|
|
||||||
|
|
||||||
def grant_permission(self, items):
|
|
||||||
action_list = items[1]
|
|
||||||
resource = items[3]
|
|
||||||
role_name = items[6]
|
|
||||||
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
|
|
||||||
|
|
||||||
def revoke_permission(self, items):
|
|
||||||
action_list = items[1]
|
|
||||||
resource = items[3]
|
|
||||||
role_name = items[6]
|
|
||||||
return {
|
|
||||||
"type": "revoke_permission",
|
|
||||||
"role_name": role_name,
|
|
||||||
"resource": resource, "actions": action_list
|
|
||||||
}
|
|
||||||
|
|
||||||
def alter_user_role(self, items):
|
|
||||||
user_name = items[2]
|
|
||||||
role_name = items[5]
|
|
||||||
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
|
|
||||||
|
|
||||||
def show_user_permission(self, items):
|
|
||||||
user_name = items[3]
|
|
||||||
return {"type": "show_user_permission", "user_name": user_name}
|
|
||||||
|
|
||||||
def show_version(self, items):
|
|
||||||
return {"type": "show_version"}
|
|
||||||
|
|
||||||
def action_list(self, items):
|
|
||||||
return items
|
|
||||||
|
|
||||||
def meta_command(self, items):
|
|
||||||
command_name = str(items[0]).lower()
|
|
||||||
args = items[1:] if len(items) > 1 else []
|
|
||||||
|
|
||||||
# handle quoted parameter
|
|
||||||
parsed_args = []
|
|
||||||
for arg in args:
|
|
||||||
if hasattr(arg, 'value'):
|
|
||||||
parsed_args.append(arg.value)
|
|
||||||
else:
|
|
||||||
parsed_args.append(str(arg))
|
|
||||||
|
|
||||||
return {'type': 'meta', 'command': command_name, 'args': parsed_args}
|
|
||||||
|
|
||||||
def meta_command_name(self, items):
|
|
||||||
return items[0]
|
|
||||||
|
|
||||||
def meta_args(self, items):
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt(input_string):
|
|
||||||
pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
|
|
||||||
pub_key = RSA.importKey(pub)
|
|
||||||
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
|
||||||
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
|
|
||||||
return base64.b64encode(cipher_text).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def encode_to_base64(input_string):
|
|
||||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
|
||||||
return base64_encoded.decode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
class AdminCLI(Cmd):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
|
|
||||||
self.command_history = []
|
|
||||||
self.is_interactive = False
|
|
||||||
self.admin_account = "admin@ragflow.io"
|
|
||||||
self.admin_password: str = "admin"
|
|
||||||
self.session = requests.Session()
|
|
||||||
self.access_token: str = ""
|
|
||||||
self.host: str = ""
|
|
||||||
self.port: int = 0
|
|
||||||
|
|
||||||
intro = r"""Type "\h" for help."""
|
|
||||||
prompt = "admin> "
|
|
||||||
|
|
||||||
def onecmd(self, command: str) -> bool:
|
|
||||||
try:
|
|
||||||
result = self.parse_command(command)
|
|
||||||
|
|
||||||
if isinstance(result, dict):
|
|
||||||
if 'type' in result and result.get('type') == 'empty':
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.execute_command(result)
|
|
||||||
|
|
||||||
if isinstance(result, Tree):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
|
|
||||||
return True
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nUse '\\q' to quit")
|
|
||||||
except EOFError:
|
|
||||||
print("\nGoodbye!")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def emptyline(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def default(self, line: str) -> bool:
|
|
||||||
return self.onecmd(line)
|
|
||||||
|
|
||||||
def parse_command(self, command_str: str) -> dict[str, str]:
|
|
||||||
if not command_str.strip():
|
|
||||||
return {'type': 'empty'}
|
|
||||||
|
|
||||||
self.command_history.append(command_str)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = self.parser.parse(command_str)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
return {'type': 'error', 'message': f'Parse error: {str(e)}'}
|
|
||||||
|
|
||||||
def verify_admin(self, arguments: dict, single_command: bool):
|
|
||||||
self.host = arguments['host']
|
|
||||||
self.port = arguments['port']
|
|
||||||
print("Attempt to access server for admin login")
|
|
||||||
url = f"http://{self.host}:{self.port}/api/v1/admin/login"
|
|
||||||
|
|
||||||
attempt_count = 3
|
|
||||||
if single_command:
|
|
||||||
attempt_count = 1
|
|
||||||
|
|
||||||
try_count = 0
|
|
||||||
while True:
|
|
||||||
try_count += 1
|
|
||||||
if try_count > attempt_count:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if single_command:
|
|
||||||
admin_passwd = arguments['password']
|
|
||||||
else:
|
|
||||||
admin_passwd = getpass.getpass(f"password for {self.admin_account}: ").strip()
|
|
||||||
try:
|
|
||||||
self.admin_password = encrypt(admin_passwd)
|
|
||||||
response = self.session.post(url, json={'email': self.admin_account, 'password': self.admin_password})
|
|
||||||
if response.status_code == 200:
|
|
||||||
res_json = response.json()
|
|
||||||
error_code = res_json.get('code', -1)
|
|
||||||
if error_code == 0:
|
|
||||||
self.session.headers.update({
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': response.headers['Authorization'],
|
|
||||||
'User-Agent': 'RAGFlow-CLI/0.22.1'
|
|
||||||
})
|
|
||||||
print("Authentication successful.")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
error_message = res_json.get('message', 'Unknown error')
|
|
||||||
print(f"Authentication failed: {error_message}, try again")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
print(f"Bad response,status: {response.status_code}, password is wrong")
|
|
||||||
except Exception as e:
|
|
||||||
print(str(e))
|
|
||||||
print("Can't access server for admin login (connection failed)")
|
|
||||||
|
|
||||||
def _format_service_detail_table(self, data):
|
|
||||||
if isinstance(data, list):
|
|
||||||
return data
|
|
||||||
if not all([isinstance(v, list) for v in data.values()]):
|
|
||||||
# normal table
|
|
||||||
return data
|
|
||||||
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
|
|
||||||
task_executor_list = []
|
|
||||||
for k, v in data.items():
|
|
||||||
# display latest status
|
|
||||||
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
|
|
||||||
task_executor_list.append({
|
|
||||||
"task_executor_name": k,
|
|
||||||
**heartbeats[0],
|
|
||||||
} if heartbeats else {"task_executor_name": k})
|
|
||||||
return task_executor_list
|
|
||||||
|
|
||||||
def _print_table_simple(self, data):
|
|
||||||
if not data:
|
|
||||||
print("No data to print")
|
|
||||||
return
|
|
||||||
if isinstance(data, dict):
|
|
||||||
# handle single row data
|
|
||||||
data = [data]
|
|
||||||
|
|
||||||
columns = list(set().union(*(d.keys() for d in data)))
|
|
||||||
columns.sort()
|
|
||||||
col_widths = {}
|
|
||||||
|
|
||||||
def get_string_width(text):
|
|
||||||
half_width_chars = (
|
|
||||||
" !\"#$%&'()*+,-./0123456789:;<=>?@"
|
|
||||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`"
|
|
||||||
"abcdefghijklmnopqrstuvwxyz{|}~"
|
|
||||||
"\t\n\r"
|
|
||||||
)
|
|
||||||
width = 0
|
|
||||||
for char in text:
|
|
||||||
if char in half_width_chars:
|
|
||||||
width += 1
|
|
||||||
else:
|
|
||||||
width += 2
|
|
||||||
return width
|
|
||||||
|
|
||||||
for col in columns:
|
|
||||||
max_width = get_string_width(str(col))
|
|
||||||
for item in data:
|
|
||||||
value_len = get_string_width(str(item.get(col, '')))
|
|
||||||
if value_len > max_width:
|
|
||||||
max_width = value_len
|
|
||||||
col_widths[col] = max(2, max_width)
|
|
||||||
|
|
||||||
# Generate delimiter
|
|
||||||
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
|
|
||||||
|
|
||||||
# Print header
|
|
||||||
print(separator)
|
|
||||||
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
|
|
||||||
print(header)
|
|
||||||
print(separator)
|
|
||||||
|
|
||||||
# Print data
|
|
||||||
for item in data:
|
|
||||||
row = "|"
|
|
||||||
for col in columns:
|
|
||||||
value = str(item.get(col, ''))
|
|
||||||
if get_string_width(value) > col_widths[col]:
|
|
||||||
value = value[:col_widths[col] - 3] + "..."
|
|
||||||
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
|
|
||||||
print(row)
|
|
||||||
|
|
||||||
print(separator)
|
|
||||||
|
|
||||||
def run_interactive(self):
|
|
||||||
|
|
||||||
self.is_interactive = True
|
|
||||||
print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
command = input("admin> ").strip()
|
|
||||||
if not command:
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"command: {command}")
|
|
||||||
result = self.parse_command(command)
|
|
||||||
self.execute_command(result)
|
|
||||||
|
|
||||||
if isinstance(result, Tree):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
|
|
||||||
break
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nUse '\\q' to quit")
|
|
||||||
except EOFError:
|
|
||||||
print("\nGoodbye!")
|
|
||||||
break
|
|
||||||
|
|
||||||
def run_single_command(self, command: str):
|
|
||||||
result = self.parse_command(command)
|
|
||||||
self.execute_command(result)
|
|
||||||
|
|
||||||
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
|
|
||||||
parser = argparse.ArgumentParser(description='Admin CLI Client', add_help=False)
|
|
||||||
parser.add_argument('-h', '--host', default='localhost', help='Admin service host')
|
|
||||||
parser.add_argument('-p', '--port', type=int, default=9381, help='Admin service port')
|
|
||||||
parser.add_argument('-w', '--password', default='admin', type=str, help='Superuser password')
|
|
||||||
parser.add_argument('command', nargs='?', help='Single command')
|
|
||||||
try:
|
|
||||||
parsed_args, remaining_args = parser.parse_known_args(args)
|
|
||||||
if remaining_args:
|
|
||||||
command = remaining_args[0]
|
|
||||||
return {
|
|
||||||
'host': parsed_args.host,
|
|
||||||
'port': parsed_args.port,
|
|
||||||
'password': parsed_args.password,
|
|
||||||
'command': command
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
'host': parsed_args.host,
|
|
||||||
'port': parsed_args.port,
|
|
||||||
}
|
|
||||||
except SystemExit:
|
|
||||||
return {'error': 'Invalid connection arguments'}
|
|
||||||
|
|
||||||
def execute_command(self, parsed_command: Dict[str, Any]):
|
|
||||||
|
|
||||||
command_dict: dict
|
|
||||||
if isinstance(parsed_command, Tree):
|
|
||||||
command_dict = parsed_command.children[0]
|
|
||||||
else:
|
|
||||||
if parsed_command['type'] == 'error':
|
|
||||||
print(f"Error: {parsed_command['message']}")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
command_dict = parsed_command
|
|
||||||
|
|
||||||
# print(f"Parsed command: {command_dict}")
|
|
||||||
|
|
||||||
command_type = command_dict['type']
|
|
||||||
|
|
||||||
match command_type:
|
|
||||||
case 'list_services':
|
|
||||||
self._handle_list_services(command_dict)
|
|
||||||
case 'show_service':
|
|
||||||
self._handle_show_service(command_dict)
|
|
||||||
case 'restart_service':
|
|
||||||
self._handle_restart_service(command_dict)
|
|
||||||
case 'shutdown_service':
|
|
||||||
self._handle_shutdown_service(command_dict)
|
|
||||||
case 'startup_service':
|
|
||||||
self._handle_startup_service(command_dict)
|
|
||||||
case 'list_users':
|
|
||||||
self._handle_list_users(command_dict)
|
|
||||||
case 'show_user':
|
|
||||||
self._handle_show_user(command_dict)
|
|
||||||
case 'drop_user':
|
|
||||||
self._handle_drop_user(command_dict)
|
|
||||||
case 'alter_user':
|
|
||||||
self._handle_alter_user(command_dict)
|
|
||||||
case 'create_user':
|
|
||||||
self._handle_create_user(command_dict)
|
|
||||||
case 'activate_user':
|
|
||||||
self._handle_activate_user(command_dict)
|
|
||||||
case 'list_datasets':
|
|
||||||
self._handle_list_datasets(command_dict)
|
|
||||||
case 'list_agents':
|
|
||||||
self._handle_list_agents(command_dict)
|
|
||||||
case 'create_role':
|
|
||||||
self._create_role(command_dict)
|
|
||||||
case 'drop_role':
|
|
||||||
self._drop_role(command_dict)
|
|
||||||
case 'alter_role':
|
|
||||||
self._alter_role(command_dict)
|
|
||||||
case 'list_roles':
|
|
||||||
self._list_roles(command_dict)
|
|
||||||
case 'show_role':
|
|
||||||
self._show_role(command_dict)
|
|
||||||
case 'grant_permission':
|
|
||||||
self._grant_permission(command_dict)
|
|
||||||
case 'revoke_permission':
|
|
||||||
self._revoke_permission(command_dict)
|
|
||||||
case 'alter_user_role':
|
|
||||||
self._alter_user_role(command_dict)
|
|
||||||
case 'show_user_permission':
|
|
||||||
self._show_user_permission(command_dict)
|
|
||||||
case 'show_version':
|
|
||||||
self._show_version(command_dict)
|
|
||||||
case 'meta':
|
|
||||||
self._handle_meta_command(command_dict)
|
|
||||||
case _:
|
|
||||||
print(f"Command '{command_type}' would be executed with API")
|
|
||||||
|
|
||||||
def _handle_list_services(self, command):
|
|
||||||
print("Listing all services")
|
|
||||||
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/services'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to get all services, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_show_service(self, command):
|
|
||||||
service_id: int = command['number']
|
|
||||||
print(f"Showing service: {service_id}")
|
|
||||||
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/services/{service_id}'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
res_data = res_json['data']
|
|
||||||
if 'status' in res_data and res_data['status'] == 'alive':
|
|
||||||
print(f"Service {res_data['service_name']} is alive, ")
|
|
||||||
if isinstance(res_data['message'], str):
|
|
||||||
print(res_data['message'])
|
|
||||||
else:
|
|
||||||
data = self._format_service_detail_table(res_data['message'])
|
|
||||||
self._print_table_simple(data)
|
|
||||||
else:
|
|
||||||
print(f"Service {res_data['service_name']} is down, {res_data['message']}")
|
|
||||||
else:
|
|
||||||
print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_restart_service(self, command):
|
|
||||||
service_id: int = command['number']
|
|
||||||
print(f"Restart service {service_id}")
|
|
||||||
|
|
||||||
def _handle_shutdown_service(self, command):
|
|
||||||
service_id: int = command['number']
|
|
||||||
print(f"Shutdown service {service_id}")
|
|
||||||
|
|
||||||
def _handle_startup_service(self, command):
|
|
||||||
service_id: int = command['number']
|
|
||||||
print(f"Startup service {service_id}")
|
|
||||||
|
|
||||||
def _handle_list_users(self, command):
|
|
||||||
print("Listing all users")
|
|
||||||
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_show_user(self, command):
|
|
||||||
username_tree: Tree = command['user_name']
|
|
||||||
user_name: str = username_tree.children[0].strip("'\"")
|
|
||||||
print(f"Showing user: {user_name}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
table_data = res_json['data']
|
|
||||||
table_data.pop('avatar')
|
|
||||||
self._print_table_simple(table_data)
|
|
||||||
else:
|
|
||||||
print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_drop_user(self, command):
|
|
||||||
username_tree: Tree = command['user_name']
|
|
||||||
user_name: str = username_tree.children[0].strip("'\"")
|
|
||||||
print(f"Drop user: {user_name}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}'
|
|
||||||
response = self.session.delete(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
print(res_json["message"])
|
|
||||||
else:
|
|
||||||
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_alter_user(self, command):
|
|
||||||
user_name_tree: Tree = command['user_name']
|
|
||||||
user_name: str = user_name_tree.children[0].strip("'\"")
|
|
||||||
password_tree: Tree = command['password']
|
|
||||||
password: str = password_tree.children[0].strip("'\"")
|
|
||||||
print(f"Alter user: {user_name}, password: ******")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password'
|
|
||||||
response = self.session.put(url, json={'new_password': encrypt(password)})
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
print(res_json["message"])
|
|
||||||
else:
|
|
||||||
print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_create_user(self, command):
|
|
||||||
user_name_tree: Tree = command['user_name']
|
|
||||||
user_name: str = user_name_tree.children[0].strip("'\"")
|
|
||||||
password_tree: Tree = command['password']
|
|
||||||
password: str = password_tree.children[0].strip("'\"")
|
|
||||||
role: str = command['role']
|
|
||||||
print(f"Create user: {user_name}, password: ******, role: {role}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
|
||||||
response = self.session.post(
|
|
||||||
url,
|
|
||||||
json={'user_name': user_name, 'password': encrypt(password), 'role': role}
|
|
||||||
)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to create user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_activate_user(self, command):
|
|
||||||
user_name_tree: Tree = command['user_name']
|
|
||||||
user_name: str = user_name_tree.children[0].strip("'\"")
|
|
||||||
activate_tree: Tree = command['activate_status']
|
|
||||||
activate_status: str = activate_tree.children[0].strip("'\"")
|
|
||||||
if activate_status.lower() in ['on', 'off']:
|
|
||||||
print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/activate'
|
|
||||||
response = self.session.put(url, json={'activate_status': activate_status})
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
print(res_json["message"])
|
|
||||||
else:
|
|
||||||
print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
else:
|
|
||||||
print(f"Unknown activate status: {activate_status}.")
|
|
||||||
|
|
||||||
def _handle_list_datasets(self, command):
|
|
||||||
username_tree: Tree = command['user_name']
|
|
||||||
user_name: str = username_tree.children[0].strip("'\"")
|
|
||||||
print(f"Listing all datasets of user: {user_name}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/datasets'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
table_data = res_json['data']
|
|
||||||
for t in table_data:
|
|
||||||
t.pop('avatar')
|
|
||||||
self._print_table_simple(table_data)
|
|
||||||
else:
|
|
||||||
print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_list_agents(self, command):
|
|
||||||
username_tree: Tree = command['user_name']
|
|
||||||
user_name: str = username_tree.children[0].strip("'\"")
|
|
||||||
print(f"Listing all agents of user: {user_name}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/agents'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
table_data = res_json['data']
|
|
||||||
for t in table_data:
|
|
||||||
t.pop('avatar')
|
|
||||||
self._print_table_simple(table_data)
|
|
||||||
else:
|
|
||||||
print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _create_role(self, command):
|
|
||||||
role_name_tree: Tree = command['role_name']
|
|
||||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
|
||||||
desc_str: str = ''
|
|
||||||
if 'description' in command:
|
|
||||||
desc_tree: Tree = command['description']
|
|
||||||
desc_str = desc_tree.children[0].strip("'\"")
|
|
||||||
|
|
||||||
print(f"create role name: {role_name}, description: {desc_str}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/roles'
|
|
||||||
response = self.session.post(
|
|
||||||
url,
|
|
||||||
json={'role_name': role_name, 'description': desc_str}
|
|
||||||
)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to create role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _drop_role(self, command):
|
|
||||||
role_name_tree: Tree = command['role_name']
|
|
||||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
|
||||||
print(f"drop role name: {role_name}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}'
|
|
||||||
response = self.session.delete(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to drop role {role_name}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _alter_role(self, command):
|
|
||||||
role_name_tree: Tree = command['role_name']
|
|
||||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
|
||||||
desc_tree: Tree = command['description']
|
|
||||||
desc_str: str = desc_tree.children[0].strip("'\"")
|
|
||||||
|
|
||||||
print(f"alter role name: {role_name}, description: {desc_str}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}'
|
|
||||||
response = self.session.put(
|
|
||||||
url,
|
|
||||||
json={'description': desc_str}
|
|
||||||
)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _list_roles(self, command):
|
|
||||||
print("Listing all roles")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/roles'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _show_role(self, command):
|
|
||||||
role_name_tree: Tree = command['role_name']
|
|
||||||
role_name: str = role_name_tree.children[0].strip("'\"")
|
|
||||||
print(f"show role: {role_name}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}/permission'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _grant_permission(self, command):
|
|
||||||
role_name_tree: Tree = command['role_name']
|
|
||||||
role_name_str: str = role_name_tree.children[0].strip("'\"")
|
|
||||||
resource_tree: Tree = command['resource']
|
|
||||||
resource_str: str = resource_tree.children[0].strip("'\"")
|
|
||||||
action_tree_list: list = command['actions']
|
|
||||||
actions: list = []
|
|
||||||
for action_tree in action_tree_list:
|
|
||||||
action_str: str = action_tree.children[0].strip("'\"")
|
|
||||||
actions.append(action_str)
|
|
||||||
print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission'
|
|
||||||
response = self.session.post(
|
|
||||||
url,
|
|
||||||
json={'actions': actions, 'resource': resource_str}
|
|
||||||
)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _revoke_permission(self, command):
|
|
||||||
role_name_tree: Tree = command['role_name']
|
|
||||||
role_name_str: str = role_name_tree.children[0].strip("'\"")
|
|
||||||
resource_tree: Tree = command['resource']
|
|
||||||
resource_str: str = resource_tree.children[0].strip("'\"")
|
|
||||||
action_tree_list: list = command['actions']
|
|
||||||
actions: list = []
|
|
||||||
for action_tree in action_tree_list:
|
|
||||||
action_str: str = action_tree.children[0].strip("'\"")
|
|
||||||
actions.append(action_str)
|
|
||||||
print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission'
|
|
||||||
response = self.session.delete(
|
|
||||||
url,
|
|
||||||
json={'actions': actions, 'resource': resource_str}
|
|
||||||
)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _alter_user_role(self, command):
|
|
||||||
role_name_tree: Tree = command['role_name']
|
|
||||||
role_name_str: str = role_name_tree.children[0].strip("'\"")
|
|
||||||
user_name_tree: Tree = command['user_name']
|
|
||||||
user_name_str: str = user_name_tree.children[0].strip("'\"")
|
|
||||||
print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/role'
|
|
||||||
response = self.session.put(
|
|
||||||
url,
|
|
||||||
json={'role_name': role_name_str}
|
|
||||||
)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _show_user_permission(self, command):
|
|
||||||
user_name_tree: Tree = command['user_name']
|
|
||||||
user_name_str: str = user_name_tree.children[0].strip("'\"")
|
|
||||||
print(f"show_user_permission user_name: {user_name_str}")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/permission'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _show_version(self, command):
|
|
||||||
print("show_version")
|
|
||||||
url = f'http://{self.host}:{self.port}/api/v1/admin/version'
|
|
||||||
response = self.session.get(url)
|
|
||||||
res_json = response.json()
|
|
||||||
if response.status_code == 200:
|
|
||||||
self._print_table_simple(res_json['data'])
|
|
||||||
else:
|
|
||||||
print(f"Fail to show version, code: {res_json['code']}, message: {res_json['message']}")
|
|
||||||
|
|
||||||
def _handle_meta_command(self, command):
|
|
||||||
meta_command = command['command']
|
|
||||||
args = command.get('args', [])
|
|
||||||
|
|
||||||
if meta_command in ['?', 'h', 'help']:
|
|
||||||
self.show_help()
|
|
||||||
elif meta_command in ['q', 'quit', 'exit']:
|
|
||||||
print("Goodbye!")
|
|
||||||
else:
|
|
||||||
print(f"Meta command '{meta_command}' with args {args}")
|
|
||||||
|
|
||||||
def show_help(self):
|
|
||||||
"""Help info"""
|
|
||||||
help_text = """
|
|
||||||
Commands:
|
|
||||||
LIST SERVICES
|
|
||||||
SHOW SERVICE <service>
|
|
||||||
STARTUP SERVICE <service>
|
|
||||||
SHUTDOWN SERVICE <service>
|
|
||||||
RESTART SERVICE <service>
|
|
||||||
LIST USERS
|
|
||||||
SHOW USER <user>
|
|
||||||
DROP USER <user>
|
|
||||||
CREATE USER <user> <password>
|
|
||||||
ALTER USER PASSWORD <user> <new_password>
|
|
||||||
ALTER USER ACTIVE <user> <on/off>
|
|
||||||
LIST DATASETS OF <user>
|
|
||||||
LIST AGENTS OF <user>
|
|
||||||
|
|
||||||
Meta Commands:
|
|
||||||
\\?, \\h, \\help Show this help
|
|
||||||
\\q, \\quit, \\exit Quit the CLI
|
|
||||||
"""
|
|
||||||
print(help_text)
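For orientation, an illustrative session with the commands listed above (user names and passwords are written as quoted strings; the handlers in this class strip the surrounding quotes before use):

    LIST USERS;
    CREATE USER 'alice@example.com' 'S3cret!';
    ALTER USER ACTIVE 'alice@example.com' off;
    LIST DATASETS OF 'alice@example.com';
    \q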
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import sys
|
|
||||||
|
|
||||||
cli = AdminCLI()
|
|
||||||
|
|
||||||
args = cli.parse_connection_args(sys.argv)
|
|
||||||
if 'error' in args:
|
|
||||||
print("Error: Invalid connection arguments")
|
|
||||||
return
|
|
||||||
|
|
||||||
if 'command' in args:
|
|
||||||
if 'password' not in args:
|
|
||||||
print("Error: password is missing")
|
|
||||||
return
|
|
||||||
if cli.verify_admin(args, single_command=True):
|
|
||||||
command: str = args['command']
|
|
||||||
# print(f"Run single command: {command}")
|
|
||||||
cli.run_single_command(command)
|
|
||||||
else:
|
|
||||||
if cli.verify_admin(args, single_command=False):
|
|
||||||
print(r"""
|
|
||||||
____ ___ ______________ ___ __ _
|
|
||||||
/ __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
|
|
||||||
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
|
|
||||||
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
|
|
||||||
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
|
|
||||||
""")
|
|
||||||
cli.cmdloop()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
admin/client/http_client.py (new file, 166 lines)
@@ -0,0 +1,166 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import json
from typing import Any, Dict, Optional

import requests
# from requests.sessions import HTTPAdapter


class HttpClient:
    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 9381,
        api_version: str = "v1",
        api_key: Optional[str] = None,
        connect_timeout: float = 5.0,
        read_timeout: float = 60.0,
        verify_ssl: bool = False,
    ) -> None:
        self.host = host
        self.port = port
        self.api_version = api_version
        self.api_key = api_key
        self.login_token: str | None = None
        self.connect_timeout = connect_timeout
        self.read_timeout = read_timeout
        self.verify_ssl = verify_ssl

    def api_base(self) -> str:
        return f"{self.host}:{self.port}/api/{self.api_version}"

    def non_api_base(self) -> str:
        return f"{self.host}:{self.port}/{self.api_version}"

    def build_url(self, path: str, use_api_base: bool = True) -> str:
        base = self.api_base() if use_api_base else self.non_api_base()
        if self.verify_ssl:
            return f"https://{base}/{path.lstrip('/')}"
        else:
            return f"http://{base}/{path.lstrip('/')}"

    def _headers(self, auth_kind: Optional[str], extra: Optional[Dict[str, str]]) -> Dict[str, str]:
        headers = {}
        if auth_kind == "api" and self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        elif auth_kind == "web" and self.login_token:
            headers["Authorization"] = self.login_token
        elif auth_kind == "admin" and self.login_token:
            headers["Authorization"] = self.login_token
        else:
            pass
        if extra:
            headers.update(extra)
        return headers

    def request(
        self,
        method: str,
        path: str,
        *,
        use_api_base: bool = True,
        auth_kind: Optional[str] = "api",
        headers: Optional[Dict[str, str]] = None,
        json_body: Optional[Dict[str, Any]] = None,
        data: Any = None,
        files: Any = None,
        params: Optional[Dict[str, Any]] = None,
        stream: bool = False,
        iterations: int = 1,
    ) -> requests.Response | dict:
        url = self.build_url(path, use_api_base=use_api_base)
        merged_headers = self._headers(auth_kind, headers)
        # timeout: Tuple[float, float] = (self.connect_timeout, self.read_timeout)
        session = requests.Session()
        # adapter = HTTPAdapter(pool_connections=100, pool_maxsize=100)
        # session.mount("http://", adapter)
        if iterations > 1:
            response_list = []
            total_duration = 0.0
            for _ in range(iterations):
                start_time = time.perf_counter()
                response = session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
                # response = requests.request(
                #     method=method,
                #     url=url,
                #     headers=merged_headers,
                #     json=json_body,
                #     data=data,
                #     files=files,
                #     params=params,
                #     timeout=timeout,
                #     stream=stream,
                #     verify=self.verify_ssl,
                # )
                end_time = time.perf_counter()
                total_duration += end_time - start_time
                response_list.append(response)
            return {"duration": total_duration, "response_list": response_list}
        else:
            return session.get(url, headers=merged_headers, json=json_body, data=data, stream=stream)
            # return requests.request(
            #     method=method,
            #     url=url,
            #     headers=merged_headers,
            #     json=json_body,
            #     data=data,
            #     files=files,
            #     params=params,
            #     timeout=timeout,
            #     stream=stream,
            #     verify=self.verify_ssl,
            # )

    def request_json(
        self,
        method: str,
        path: str,
        *,
        use_api_base: bool = True,
        auth_kind: Optional[str] = "api",
        headers: Optional[Dict[str, str]] = None,
        json_body: Optional[Dict[str, Any]] = None,
        data: Any = None,
        files: Any = None,
        params: Optional[Dict[str, Any]] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        response = self.request(
            method,
            path,
            use_api_base=use_api_base,
            auth_kind=auth_kind,
            headers=headers,
            json_body=json_body,
            data=data,
            files=files,
            params=params,
            stream=stream,
        )
        try:
            return response.json()
        except Exception as exc:
            raise ValueError(f"Non-JSON response from {path}: {exc}") from exc

    @staticmethod
    def parse_json_bytes(raw: bytes) -> Dict[str, Any]:
        try:
            return json.loads(raw.decode("utf-8"))
        except Exception as exc:
            raise ValueError(f"Invalid JSON payload: {exc}") from exc
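A minimal usage sketch for HttpClient (the host/port values and a reachable server are assumptions; note that in this revision request() always goes through session.get(), so the method argument is effectively ignored):

    from http_client import HttpClient

    client = HttpClient(host="127.0.0.1", port=9381)
    print(client.build_url("admin/version"))  # -> http://127.0.0.1:9381/api/v1/admin/version
    payload = client.request_json("GET", "admin/version", auth_kind="admin")
    print(payload.get("data"))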
admin/client/parser.py (new file, 623 lines)
@@ -0,0 +1,623 @@
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from lark import Transformer
|
||||||
|
|
||||||
|
GRAMMAR = r"""
|
||||||
|
start: command
|
||||||
|
|
||||||
|
command: sql_command | meta_command
|
||||||
|
|
||||||
|
sql_command: login_user
|
||||||
|
| ping_server
|
||||||
|
| list_services
|
||||||
|
| show_service
|
||||||
|
| startup_service
|
||||||
|
| shutdown_service
|
||||||
|
| restart_service
|
||||||
|
| register_user
|
||||||
|
| list_users
|
||||||
|
| show_user
|
||||||
|
| drop_user
|
||||||
|
| alter_user
|
||||||
|
| create_user
|
||||||
|
| activate_user
|
||||||
|
| list_datasets
|
||||||
|
| list_agents
|
||||||
|
| create_role
|
||||||
|
| drop_role
|
||||||
|
| alter_role
|
||||||
|
| list_roles
|
||||||
|
| show_role
|
||||||
|
| grant_permission
|
||||||
|
| revoke_permission
|
||||||
|
| alter_user_role
|
||||||
|
| show_user_permission
|
||||||
|
| show_version
|
||||||
|
| grant_admin
|
||||||
|
| revoke_admin
|
||||||
|
| set_variable
|
||||||
|
| show_variable
|
||||||
|
| list_variables
|
||||||
|
| list_configs
|
||||||
|
| list_environments
|
||||||
|
| generate_key
|
||||||
|
| list_keys
|
||||||
|
| drop_key
|
||||||
|
| show_current_user
|
||||||
|
| set_default_llm
|
||||||
|
| set_default_vlm
|
||||||
|
| set_default_embedding
|
||||||
|
| set_default_reranker
|
||||||
|
| set_default_asr
|
||||||
|
| set_default_tts
|
||||||
|
| reset_default_llm
|
||||||
|
| reset_default_vlm
|
||||||
|
| reset_default_embedding
|
||||||
|
| reset_default_reranker
|
||||||
|
| reset_default_asr
|
||||||
|
| reset_default_tts
|
||||||
|
| create_model_provider
|
||||||
|
| drop_model_provider
|
||||||
|
| create_user_dataset_with_parser
|
||||||
|
| create_user_dataset_with_pipeline
|
||||||
|
| drop_user_dataset
|
||||||
|
| list_user_datasets
|
||||||
|
| list_user_dataset_files
|
||||||
|
| list_user_agents
|
||||||
|
| list_user_chats
|
||||||
|
| create_user_chat
|
||||||
|
| drop_user_chat
|
||||||
|
| list_user_model_providers
|
||||||
|
| list_user_default_models
|
||||||
|
| parse_dataset_docs
|
||||||
|
| parse_dataset_sync
|
||||||
|
| parse_dataset_async
|
||||||
|
| import_docs_into_dataset
|
||||||
|
| search_on_datasets
|
||||||
|
| benchmark
|
||||||
|
|
||||||
|
// meta command definition
|
||||||
|
meta_command: "\\" meta_command_name [meta_args]
|
||||||
|
|
||||||
|
meta_command_name: /[a-zA-Z?]+/
|
||||||
|
meta_args: (meta_arg)+
|
||||||
|
|
||||||
|
meta_arg: /[^\\s"']+/ | quoted_string
|
||||||
|
|
||||||
|
// command definition
|
||||||
|
|
||||||
|
LOGIN: "LOGIN"i
|
||||||
|
REGISTER: "REGISTER"i
|
||||||
|
LIST: "LIST"i
|
||||||
|
SERVICES: "SERVICES"i
|
||||||
|
SHOW: "SHOW"i
|
||||||
|
CREATE: "CREATE"i
|
||||||
|
SERVICE: "SERVICE"i
|
||||||
|
SHUTDOWN: "SHUTDOWN"i
|
||||||
|
STARTUP: "STARTUP"i
|
||||||
|
RESTART: "RESTART"i
|
||||||
|
USERS: "USERS"i
|
||||||
|
DROP: "DROP"i
|
||||||
|
USER: "USER"i
|
||||||
|
ALTER: "ALTER"i
|
||||||
|
ACTIVE: "ACTIVE"i
|
||||||
|
ADMIN: "ADMIN"i
|
||||||
|
PASSWORD: "PASSWORD"i
|
||||||
|
DATASET: "DATASET"i
|
||||||
|
DATASETS: "DATASETS"i
|
||||||
|
OF: "OF"i
|
||||||
|
AGENTS: "AGENTS"i
|
||||||
|
ROLE: "ROLE"i
|
||||||
|
ROLES: "ROLES"i
|
||||||
|
DESCRIPTION: "DESCRIPTION"i
|
||||||
|
GRANT: "GRANT"i
|
||||||
|
REVOKE: "REVOKE"i
|
||||||
|
ALL: "ALL"i
|
||||||
|
PERMISSION: "PERMISSION"i
|
||||||
|
TO: "TO"i
|
||||||
|
FROM: "FROM"i
|
||||||
|
FOR: "FOR"i
|
||||||
|
RESOURCES: "RESOURCES"i
|
||||||
|
ON: "ON"i
|
||||||
|
SET: "SET"i
|
||||||
|
RESET: "RESET"i
|
||||||
|
VERSION: "VERSION"i
|
||||||
|
VAR: "VAR"i
|
||||||
|
VARS: "VARS"i
|
||||||
|
CONFIGS: "CONFIGS"i
|
||||||
|
ENVS: "ENVS"i
|
||||||
|
KEY: "KEY"i
|
||||||
|
KEYS: "KEYS"i
|
||||||
|
GENERATE: "GENERATE"i
|
||||||
|
MODEL: "MODEL"i
|
||||||
|
MODELS: "MODELS"i
|
||||||
|
PROVIDER: "PROVIDER"i
|
||||||
|
PROVIDERS: "PROVIDERS"i
|
||||||
|
DEFAULT: "DEFAULT"i
|
||||||
|
CHATS: "CHATS"i
|
||||||
|
CHAT: "CHAT"i
|
||||||
|
FILES: "FILES"i
|
||||||
|
AS: "AS"i
|
||||||
|
PARSE: "PARSE"i
|
||||||
|
IMPORT: "IMPORT"i
|
||||||
|
INTO: "INTO"i
|
||||||
|
WITH: "WITH"i
|
||||||
|
PARSER: "PARSER"i
|
||||||
|
PIPELINE: "PIPELINE"i
|
||||||
|
SEARCH: "SEARCH"i
|
||||||
|
CURRENT: "CURRENT"i
|
||||||
|
LLM: "LLM"i
|
||||||
|
VLM: "VLM"i
|
||||||
|
EMBEDDING: "EMBEDDING"i
|
||||||
|
RERANKER: "RERANKER"i
|
||||||
|
ASR: "ASR"i
|
||||||
|
TTS: "TTS"i
|
||||||
|
ASYNC: "ASYNC"i
|
||||||
|
SYNC: "SYNC"i
|
||||||
|
BENCHMARK: "BENCHMARK"i
|
||||||
|
PING: "PING"i
|
||||||
|
|
||||||
|
login_user: LOGIN USER quoted_string ";"
|
||||||
|
list_services: LIST SERVICES ";"
|
||||||
|
show_service: SHOW SERVICE NUMBER ";"
|
||||||
|
startup_service: STARTUP SERVICE NUMBER ";"
|
||||||
|
shutdown_service: SHUTDOWN SERVICE NUMBER ";"
|
||||||
|
restart_service: RESTART SERVICE NUMBER ";"
|
||||||
|
|
||||||
|
register_user: REGISTER USER quoted_string AS quoted_string PASSWORD quoted_string ";"
|
||||||
|
list_users: LIST USERS ";"
|
||||||
|
drop_user: DROP USER quoted_string ";"
|
||||||
|
alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
|
||||||
|
show_user: SHOW USER quoted_string ";"
|
||||||
|
create_user: CREATE USER quoted_string quoted_string ";"
|
||||||
|
activate_user: ALTER USER ACTIVE quoted_string status ";"
|
||||||
|
|
||||||
|
list_datasets: LIST DATASETS OF quoted_string ";"
|
||||||
|
list_agents: LIST AGENTS OF quoted_string ";"
|
||||||
|
|
||||||
|
create_role: CREATE ROLE identifier [DESCRIPTION quoted_string] ";"
|
||||||
|
drop_role: DROP ROLE identifier ";"
|
||||||
|
alter_role: ALTER ROLE identifier SET DESCRIPTION quoted_string ";"
|
||||||
|
list_roles: LIST ROLES ";"
|
||||||
|
show_role: SHOW ROLE identifier ";"
|
||||||
|
|
||||||
|
grant_permission: GRANT identifier_list ON identifier TO ROLE identifier ";"
|
||||||
|
revoke_permission: REVOKE identifier_list ON identifier FROM ROLE identifier ";"
|
||||||
|
alter_user_role: ALTER USER quoted_string SET ROLE identifier ";"
|
||||||
|
show_user_permission: SHOW USER PERMISSION quoted_string ";"
|
||||||
|
|
||||||
|
show_version: SHOW VERSION ";"
|
||||||
|
|
||||||
|
grant_admin: GRANT ADMIN quoted_string ";"
|
||||||
|
revoke_admin: REVOKE ADMIN quoted_string ";"
|
||||||
|
|
||||||
|
generate_key: GENERATE KEY FOR USER quoted_string ";"
|
||||||
|
list_keys: LIST KEYS OF quoted_string ";"
|
||||||
|
drop_key: DROP KEY quoted_string OF quoted_string ";"
|
||||||
|
|
||||||
|
set_variable: SET VAR identifier identifier ";"
|
||||||
|
show_variable: SHOW VAR identifier ";"
|
||||||
|
list_variables: LIST VARS ";"
|
||||||
|
list_configs: LIST CONFIGS ";"
|
||||||
|
list_environments: LIST ENVS ";"
|
||||||
|
|
||||||
|
benchmark: BENCHMARK NUMBER NUMBER user_statement
|
||||||
|
|
||||||
|
user_statement: ping_server
|
||||||
|
| show_current_user
|
||||||
|
| create_model_provider
|
||||||
|
| drop_model_provider
|
||||||
|
| set_default_llm
|
||||||
|
| set_default_vlm
|
||||||
|
| set_default_embedding
|
||||||
|
| set_default_reranker
|
||||||
|
| set_default_asr
|
||||||
|
| set_default_tts
|
||||||
|
| reset_default_llm
|
||||||
|
| reset_default_vlm
|
||||||
|
| reset_default_embedding
|
||||||
|
| reset_default_reranker
|
||||||
|
| reset_default_asr
|
||||||
|
| reset_default_tts
|
||||||
|
| create_user_dataset_with_parser
|
||||||
|
| create_user_dataset_with_pipeline
|
||||||
|
| drop_user_dataset
|
||||||
|
| list_user_datasets
|
||||||
|
| list_user_dataset_files
|
||||||
|
| list_user_agents
|
||||||
|
| list_user_chats
|
||||||
|
| create_user_chat
|
||||||
|
| drop_user_chat
|
||||||
|
| list_user_model_providers
|
||||||
|
| list_user_default_models
|
||||||
|
| import_docs_into_dataset
|
||||||
|
| search_on_datasets
|
||||||
|
|
||||||
|
ping_server: PING ";"
|
||||||
|
show_current_user: SHOW CURRENT USER ";"
|
||||||
|
create_model_provider: CREATE MODEL PROVIDER quoted_string quoted_string ";"
|
||||||
|
drop_model_provider: DROP MODEL PROVIDER quoted_string ";"
|
||||||
|
set_default_llm: SET DEFAULT LLM quoted_string ";"
|
||||||
|
set_default_vlm: SET DEFAULT VLM quoted_string ";"
|
||||||
|
set_default_embedding: SET DEFAULT EMBEDDING quoted_string ";"
|
||||||
|
set_default_reranker: SET DEFAULT RERANKER quoted_string ";"
|
||||||
|
set_default_asr: SET DEFAULT ASR quoted_string ";"
|
||||||
|
set_default_tts: SET DEFAULT TTS quoted_string ";"
|
||||||
|
|
||||||
|
reset_default_llm: RESET DEFAULT LLM ";"
|
||||||
|
reset_default_vlm: RESET DEFAULT VLM ";"
|
||||||
|
reset_default_embedding: RESET DEFAULT EMBEDDING ";"
|
||||||
|
reset_default_reranker: RESET DEFAULT RERANKER ";"
|
||||||
|
reset_default_asr: RESET DEFAULT ASR ";"
|
||||||
|
reset_default_tts: RESET DEFAULT TTS ";"
|
||||||
|
|
||||||
|
list_user_datasets: LIST DATASETS ";"
|
||||||
|
create_user_dataset_with_parser: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PARSER quoted_string ";"
|
||||||
|
create_user_dataset_with_pipeline: CREATE DATASET quoted_string WITH EMBEDDING quoted_string PIPELINE quoted_string ";"
|
||||||
|
drop_user_dataset: DROP DATASET quoted_string ";"
|
||||||
|
list_user_dataset_files: LIST FILES OF DATASET quoted_string ";"
|
||||||
|
list_user_agents: LIST AGENTS ";"
|
||||||
|
list_user_chats: LIST CHATS ";"
|
||||||
|
create_user_chat: CREATE CHAT quoted_string ";"
|
||||||
|
drop_user_chat: DROP CHAT quoted_string ";"
|
||||||
|
list_user_model_providers: LIST MODEL PROVIDERS ";"
|
||||||
|
list_user_default_models: LIST DEFAULT MODELS ";"
|
||||||
|
import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";"
|
||||||
|
search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";"
|
||||||
|
|
||||||
|
parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
|
||||||
|
parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
|
||||||
|
parse_dataset_async: PARSE DATASET quoted_string ASYNC ";"
|
||||||
|
|
||||||
|
identifier_list: identifier ("," identifier)*
|
||||||
|
|
||||||
|
identifier: WORD
|
||||||
|
quoted_string: QUOTED_STRING
|
||||||
|
status: WORD
|
||||||
|
|
||||||
|
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
|
||||||
|
WORD: /[a-zA-Z0-9_\-\.]+/
|
||||||
|
NUMBER: /[0-9]+/
|
||||||
|
|
||||||
|
%import common.WS
|
||||||
|
%ignore WS
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class RAGFlowCLITransformer(Transformer):
|
||||||
|
def start(self, items):
|
||||||
|
return items[0]
|
||||||
|
|
||||||
|
def command(self, items):
|
||||||
|
return items[0]
|
||||||
|
|
||||||
|
def login_user(self, items):
|
||||||
|
email = items[2].children[0].strip("'\"")
|
||||||
|
return {"type": "login_user", "email": email}
|
||||||
|
|
||||||
|
def ping_server(self, items):
|
||||||
|
return {"type": "ping_server"}
|
||||||
|
|
||||||
|
def list_services(self, items):
|
||||||
|
result = {"type": "list_services"}
|
||||||
|
return result
|
||||||
|
|
||||||
|
def show_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "show_service", "number": service_id}
|
||||||
|
|
||||||
|
def startup_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "startup_service", "number": service_id}
|
||||||
|
|
||||||
|
def shutdown_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "shutdown_service", "number": service_id}
|
||||||
|
|
||||||
|
def restart_service(self, items):
|
||||||
|
service_id = int(items[2])
|
||||||
|
return {"type": "restart_service", "number": service_id}
|
||||||
|
|
||||||
|
def register_user(self, items):
|
||||||
|
user_name: str = items[2].children[0].strip("'\"")
|
||||||
|
nickname: str = items[4].children[0].strip("'\"")
|
||||||
|
password: str = items[6].children[0].strip("'\"")
|
||||||
|
return {"type": "register_user", "user_name": user_name, "nickname": nickname, "password": password}
|
||||||
|
|
||||||
|
def list_users(self, items):
|
||||||
|
return {"type": "list_users"}
|
||||||
|
|
||||||
|
def show_user(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
return {"type": "show_user", "user_name": user_name}
|
||||||
|
|
||||||
|
def drop_user(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
return {"type": "drop_user", "user_name": user_name}
|
||||||
|
|
||||||
|
def alter_user(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
new_password = items[4]
|
||||||
|
return {"type": "alter_user", "user_name": user_name, "password": new_password}
|
||||||
|
|
||||||
|
def create_user(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
password = items[3]
|
||||||
|
return {"type": "create_user", "user_name": user_name, "password": password, "role": "user"}
|
||||||
|
|
||||||
|
def activate_user(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
activate_status = items[4]
|
||||||
|
return {"type": "activate_user", "activate_status": activate_status, "user_name": user_name}
|
||||||
|
|
||||||
|
def list_datasets(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
return {"type": "list_datasets", "user_name": user_name}
|
||||||
|
|
||||||
|
def list_agents(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
return {"type": "list_agents", "user_name": user_name}
|
||||||
|
|
||||||
|
def create_role(self, items):
|
||||||
|
role_name = items[2]
|
||||||
|
if len(items) > 4:
|
||||||
|
description = items[4]
|
||||||
|
return {"type": "create_role", "role_name": role_name, "description": description}
|
||||||
|
else:
|
||||||
|
return {"type": "create_role", "role_name": role_name}
|
||||||
|
|
||||||
|
def drop_role(self, items):
|
||||||
|
role_name = items[2]
|
||||||
|
return {"type": "drop_role", "role_name": role_name}
|
||||||
|
|
||||||
|
def alter_role(self, items):
|
||||||
|
role_name = items[2]
|
||||||
|
description = items[5]
|
||||||
|
return {"type": "alter_role", "role_name": role_name, "description": description}
|
||||||
|
|
||||||
|
def list_roles(self, items):
|
||||||
|
return {"type": "list_roles"}
|
||||||
|
|
||||||
|
def show_role(self, items):
|
||||||
|
role_name = items[2]
|
||||||
|
return {"type": "show_role", "role_name": role_name}
|
||||||
|
|
||||||
|
def grant_permission(self, items):
|
||||||
|
action_list = items[1]
|
||||||
|
resource = items[3]
|
||||||
|
role_name = items[6]
|
||||||
|
return {"type": "grant_permission", "role_name": role_name, "resource": resource, "actions": action_list}
|
||||||
|
|
||||||
|
def revoke_permission(self, items):
|
||||||
|
action_list = items[1]
|
||||||
|
resource = items[3]
|
||||||
|
role_name = items[6]
|
||||||
|
return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list}
|
||||||
|
|
||||||
|
def alter_user_role(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
role_name = items[5]
|
||||||
|
return {"type": "alter_user_role", "user_name": user_name, "role_name": role_name}
|
||||||
|
|
||||||
|
def show_user_permission(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
return {"type": "show_user_permission", "user_name": user_name}
|
||||||
|
|
||||||
|
def show_version(self, items):
|
||||||
|
return {"type": "show_version"}
|
||||||
|
|
||||||
|
def grant_admin(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
return {"type": "grant_admin", "user_name": user_name}
|
||||||
|
|
||||||
|
def revoke_admin(self, items):
|
||||||
|
user_name = items[2]
|
||||||
|
return {"type": "revoke_admin", "user_name": user_name}
|
||||||
|
|
||||||
|
def generate_key(self, items):
|
||||||
|
user_name = items[4]
|
||||||
|
return {"type": "generate_key", "user_name": user_name}
|
||||||
|
|
||||||
|
def list_keys(self, items):
|
||||||
|
user_name = items[3]
|
||||||
|
return {"type": "list_keys", "user_name": user_name}
|
||||||
|
|
||||||
|
def drop_key(self, items):
|
||||||
|
key = items[2]
|
||||||
|
user_name = items[4]
|
||||||
|
return {"type": "drop_key", "key": key, "user_name": user_name}
|
||||||
|
|
||||||
|
def set_variable(self, items):
|
||||||
|
var_name = items[2]
|
||||||
|
var_value = items[3]
|
||||||
|
return {"type": "set_variable", "var_name": var_name, "var_value": var_value}
|
||||||
|
|
||||||
|
def show_variable(self, items):
|
||||||
|
var_name = items[2]
|
||||||
|
return {"type": "show_variable", "var_name": var_name}
|
||||||
|
|
||||||
|
def list_variables(self, items):
|
||||||
|
return {"type": "list_variables"}
|
||||||
|
|
||||||
|
def list_configs(self, items):
|
||||||
|
return {"type": "list_configs"}
|
||||||
|
|
||||||
|
def list_environments(self, items):
|
||||||
|
return {"type": "list_environments"}
|
||||||
|
|
||||||
|
def create_model_provider(self, items):
|
||||||
|
provider_name = items[3].children[0].strip("'\"")
|
||||||
|
provider_key = items[4].children[0].strip("'\"")
|
||||||
|
return {"type": "create_model_provider", "provider_name": provider_name, "provider_key": provider_key}
|
||||||
|
|
||||||
|
def drop_model_provider(self, items):
|
||||||
|
provider_name = items[3].children[0].strip("'\"")
|
||||||
|
return {"type": "drop_model_provider", "provider_name": provider_name}
|
||||||
|
|
||||||
|
def show_current_user(self, items):
|
||||||
|
return {"type": "show_current_user"}
|
||||||
|
|
||||||
|
def set_default_llm(self, items):
|
||||||
|
llm_id = items[3].children[0].strip("'\"")
|
||||||
|
return {"type": "set_default_model", "model_type": "llm_id", "model_id": llm_id}
|
||||||
|
|
||||||
|
def set_default_vlm(self, items):
|
||||||
|
vlm_id = items[3].children[0].strip("'\"")
|
||||||
|
return {"type": "set_default_model", "model_type": "img2txt_id", "model_id": vlm_id}
|
||||||
|
|
||||||
|
def set_default_embedding(self, items):
|
||||||
|
embedding_id = items[3].children[0].strip("'\"")
|
||||||
|
return {"type": "set_default_model", "model_type": "embd_id", "model_id": embedding_id}
|
||||||
|
|
||||||
|
def set_default_reranker(self, items):
|
||||||
|
reranker_id = items[3].children[0].strip("'\"")
|
||||||
|
return {"type": "set_default_model", "model_type": "reranker_id", "model_id": reranker_id}
|
||||||
|
|
||||||
|
def set_default_asr(self, items):
|
||||||
|
asr_id = items[3].children[0].strip("'\"")
|
||||||
|
return {"type": "set_default_model", "model_type": "asr_id", "model_id": asr_id}
|
||||||
|
|
||||||
|
def set_default_tts(self, items):
|
||||||
|
tts_id = items[3].children[0].strip("'\"")
|
||||||
|
return {"type": "set_default_model", "model_type": "tts_id", "model_id": tts_id}
|
||||||
|
|
||||||
|
def reset_default_llm(self, items):
|
||||||
|
return {"type": "reset_default_model", "model_type": "llm_id"}
|
||||||
|
|
||||||
|
def reset_default_vlm(self, items):
|
||||||
|
return {"type": "reset_default_model", "model_type": "img2txt_id"}
|
||||||
|
|
||||||
|
def reset_default_embedding(self, items):
|
||||||
|
return {"type": "reset_default_model", "model_type": "embd_id"}
|
||||||
|
|
||||||
|
def reset_default_reranker(self, items):
|
||||||
|
return {"type": "reset_default_model", "model_type": "reranker_id"}
|
||||||
|
|
||||||
|
def reset_default_asr(self, items):
|
||||||
|
return {"type": "reset_default_model", "model_type": "asr_id"}
|
||||||
|
|
||||||
|
def reset_default_tts(self, items):
|
||||||
|
return {"type": "reset_default_model", "model_type": "tts_id"}
|
||||||
|
|
||||||
|
def list_user_datasets(self, items):
|
||||||
|
return {"type": "list_user_datasets"}
|
||||||
|
|
||||||
|
def create_user_dataset_with_parser(self, items):
|
||||||
|
dataset_name = items[2].children[0].strip("'\"")
|
||||||
|
embedding = items[5].children[0].strip("'\"")
|
||||||
|
parser_type = items[7].children[0].strip("'\"")
|
||||||
|
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
|
||||||
|
"parser_type": parser_type}
|
||||||
|
|
||||||
|
def create_user_dataset_with_pipeline(self, items):
|
||||||
|
dataset_name = items[2].children[0].strip("'\"")
|
||||||
|
embedding = items[5].children[0].strip("'\"")
|
||||||
|
pipeline = items[7].children[0].strip("'\"")
|
||||||
|
return {"type": "create_user_dataset", "dataset_name": dataset_name, "embedding": embedding,
|
||||||
|
"pipeline": pipeline}
|
||||||
|
|
||||||
|
def drop_user_dataset(self, items):
|
||||||
|
dataset_name = items[2].children[0].strip("'\"")
|
||||||
|
return {"type": "drop_user_dataset", "dataset_name": dataset_name}
|
||||||
|
|
||||||
|
def list_user_dataset_files(self, items):
|
||||||
|
dataset_name = items[4].children[0].strip("'\"")
|
||||||
|
return {"type": "list_user_dataset_files", "dataset_name": dataset_name}
|
||||||
|
|
||||||
|
def list_user_agents(self, items):
|
||||||
|
return {"type": "list_user_agents"}
|
||||||
|
|
||||||
|
def list_user_chats(self, items):
|
||||||
|
return {"type": "list_user_chats"}
|
||||||
|
|
||||||
|
def create_user_chat(self, items):
|
||||||
|
chat_name = items[2].children[0].strip("'\"")
|
||||||
|
return {"type": "create_user_chat", "chat_name": chat_name}
|
||||||
|
|
||||||
|
def drop_user_chat(self, items):
|
||||||
|
chat_name = items[2].children[0].strip("'\"")
|
||||||
|
return {"type": "drop_user_chat", "chat_name": chat_name}
|
||||||
|
|
||||||
|
def list_user_model_providers(self, items):
|
||||||
|
return {"type": "list_user_model_providers"}
|
||||||
|
|
||||||
|
def list_user_default_models(self, items):
|
||||||
|
return {"type": "list_user_default_models"}
|
||||||
|
|
||||||
|
def parse_dataset_docs(self, items):
|
||||||
|
document_list_str = items[1].children[0].strip("'\"")
|
||||||
|
document_names = document_list_str.split(",")
|
||||||
|
if len(document_names) == 1:
|
||||||
|
document_names = document_names[0]
|
||||||
|
document_names = document_names.split(" ")
|
||||||
|
dataset_name = items[4].children[0].strip("'\"")
|
||||||
|
return {"type": "parse_dataset_docs", "dataset_name": dataset_name, "document_names": document_names}
|
||||||
|
|
||||||
|
def parse_dataset_sync(self, items):
|
||||||
|
dataset_name = items[2].children[0].strip("'\"")
|
||||||
|
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "sync"}
|
||||||
|
|
||||||
|
def parse_dataset_async(self, items):
|
||||||
|
dataset_name = items[2].children[0].strip("'\"")
|
||||||
|
return {"type": "parse_dataset", "dataset_name": dataset_name, "method": "async"}
|
||||||
|
|
||||||
|
def import_docs_into_dataset(self, items):
|
||||||
|
document_list_str = items[1].children[0].strip("'\"")
|
||||||
|
document_paths = document_list_str.split(",")
|
||||||
|
if len(document_paths) == 1:
|
||||||
|
document_paths = document_paths[0]
|
||||||
|
document_paths = document_paths.split(" ")
|
||||||
|
dataset_name = items[4].children[0].strip("'\"")
|
||||||
|
return {"type": "import_docs_into_dataset", "dataset_name": dataset_name, "document_paths": document_paths}
|
||||||
|
|
||||||
|
def search_on_datasets(self, items):
|
||||||
|
question = items[1].children[0].strip("'\"")
|
||||||
|
datasets_str = items[4].children[0].strip("'\"")
|
||||||
|
datasets = datasets_str.split(",")
|
||||||
|
if len(datasets) == 1:
|
||||||
|
datasets = datasets[0]
|
||||||
|
datasets = datasets.split(" ")
|
||||||
|
return {"type": "search_on_datasets", "datasets": datasets, "question": question}
|
||||||
|
|
||||||
|
def benchmark(self, items):
|
||||||
|
concurrency: int = int(items[1])
|
||||||
|
iterations: int = int(items[2])
|
||||||
|
command = items[3].children[0]
|
||||||
|
return {"type": "benchmark", "concurrency": concurrency, "iterations": iterations, "command": command}
|
||||||
|
|
||||||
|
def action_list(self, items):
|
||||||
|
return items
|
||||||
|
|
||||||
|
def meta_command(self, items):
|
||||||
|
command_name = str(items[0]).lower()
|
||||||
|
args = items[1:] if len(items) > 1 else []
|
||||||
|
|
||||||
|
# handle quoted parameter
|
||||||
|
parsed_args = []
|
||||||
|
for arg in args:
|
||||||
|
if hasattr(arg, "value"):
|
||||||
|
parsed_args.append(arg.value)
|
||||||
|
else:
|
||||||
|
parsed_args.append(str(arg))
|
||||||
|
|
||||||
|
return {"type": "meta", "command": command_name, "args": parsed_args}
|
||||||
|
|
||||||
|
def meta_command_name(self, items):
|
||||||
|
return items[0]
|
||||||
|
|
||||||
|
def meta_args(self, items):
|
||||||
|
return items
|
||||||
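A minimal sketch of how the grammar and transformer above fit together, mirroring the Lark(...) construction in admin/client/ragflow_cli.py below (the command string is illustrative):

    from lark import Lark
    from parser import GRAMMAR, RAGFlowCLITransformer

    cli_parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
    result = cli_parser.parse("LIST DATASETS OF 'alice@example.com';")
    # With the transformer attached, the sql_command subtree carries a plain dict,
    # e.g. {'type': 'list_datasets', 'user_name': <quoted_string subtree>}
    print(result)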
@@ -1,6 +1,6 @@
 [project]
 name = "ragflow-cli"
-version = "0.22.1"
+version = "0.23.1"
 description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
 authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
 license = { text = "Apache License, Version 2.0" }
@@ -20,5 +20,8 @@ test = [
     "requests-toolbelt>=1.0.0",
 ]
 
+[tool.setuptools]
+py-modules = ["ragflow_cli", "parser"]
+
 [project.scripts]
-ragflow-cli = "admin_client:main"
+ragflow-cli = "ragflow_cli:main"
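Net effect of the packaging change above: the ragflow-cli console script now resolves to ragflow_cli.main instead of admin_client.main, and py-modules ships the ragflow_cli and parser modules with the wheel. A quick post-install sanity check (equivalent to running the console script):

    from ragflow_cli import main  # the new [project.scripts] target

    main()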
admin/client/ragflow_cli.py (new file, 322 lines)
@@ -0,0 +1,322 @@
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import getpass
|
||||||
|
from cmd import Cmd
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import warnings
|
||||||
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||||
|
from Cryptodome.PublicKey import RSA
|
||||||
|
from lark import Lark, Tree
|
||||||
|
from parser import GRAMMAR, RAGFlowCLITransformer
|
||||||
|
from http_client import HttpClient
|
||||||
|
from ragflow_client import RAGFlowClient, run_command
|
||||||
|
from user import login_user
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", category=getpass.GetPassWarning)
|
||||||
|
|
||||||
|
def encrypt(input_string):
|
||||||
|
pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----"
|
||||||
|
pub_key = RSA.importKey(pub)
|
||||||
|
cipher = Cipher_pkcs1_v1_5.new(pub_key)
|
||||||
|
cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8")))
|
||||||
|
return base64.b64encode(cipher_text).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_to_base64(input_string):
|
||||||
|
base64_encoded = base64.b64encode(input_string.encode("utf-8"))
|
||||||
|
return base64_encoded.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RAGFlowCLI(Cmd):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=RAGFlowCLITransformer())
|
||||||
|
self.command_history = []
|
||||||
|
self.account = "admin@ragflow.io"
|
||||||
|
self.account_password: str = "admin"
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.host: str = ""
|
||||||
|
self.port: int = 0
|
||||||
|
self.mode: str = "admin"
|
||||||
|
self.ragflow_client = None
|
||||||
|
|
||||||
|
intro = r"""Type "\h" for help."""
|
||||||
|
prompt = "ragflow> "
|
||||||
|
|
||||||
|
def onecmd(self, command: str) -> bool:
|
||||||
|
try:
|
||||||
|
result = self.parse_command(command)
|
||||||
|
|
||||||
|
if isinstance(result, dict):
|
||||||
|
if "type" in result and result.get("type") == "empty":
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.execute_command(result)
|
||||||
|
|
||||||
|
if isinstance(result, Tree):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nUse '\\q' to quit")
|
||||||
|
except EOFError:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def emptyline(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def default(self, line: str) -> bool:
|
||||||
|
return self.onecmd(line)
|
||||||
|
|
||||||
|
def parse_command(self, command_str: str) -> dict[str, str]:
|
||||||
|
if not command_str.strip():
|
||||||
|
return {"type": "empty"}
|
||||||
|
|
||||||
|
self.command_history.append(command_str)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.parser.parse(command_str)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return {"type": "error", "message": f"Parse error: {str(e)}"}
|
||||||
|
|
||||||
|
def verify_auth(self, arguments: dict, single_command: bool, auth: bool):
|
||||||
|
server_type = arguments.get("type", "admin")
|
||||||
|
http_client = HttpClient(arguments["host"], arguments["port"])
|
||||||
|
if not auth:
|
||||||
|
self.ragflow_client = RAGFlowClient(http_client, server_type)
|
||||||
|
return True
|
||||||
|
|
||||||
|
user_name = arguments["username"]
|
||||||
|
attempt_count = 3
|
||||||
|
if single_command:
|
||||||
|
attempt_count = 1
|
||||||
|
|
||||||
|
try_count = 0
|
||||||
|
while True:
|
||||||
|
try_count += 1
|
||||||
|
if try_count > attempt_count:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if single_command:
|
||||||
|
user_password = arguments["password"]
|
||||||
|
else:
|
||||||
|
user_password = getpass.getpass(f"password for {user_name}: ").strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
token = login_user(http_client, server_type, user_name, user_password)
|
||||||
|
http_client.login_token = token
|
||||||
|
self.ragflow_client = RAGFlowClient(http_client, server_type)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
print("Can't access server for login (connection failed)")
|
||||||
|
|
||||||
|
def _format_service_detail_table(self, data):
|
||||||
|
if isinstance(data, list):
|
||||||
|
return data
|
||||||
|
if not all([isinstance(v, list) for v in data.values()]):
|
||||||
|
# normal table
|
||||||
|
return data
|
||||||
|
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
|
||||||
|
task_executor_list = []
|
||||||
|
for k, v in data.items():
|
||||||
|
# display latest status
|
||||||
|
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
|
||||||
|
task_executor_list.append(
|
||||||
|
{
|
||||||
|
"task_executor_name": k,
|
||||||
|
**heartbeats[0],
|
||||||
|
}
|
||||||
|
if heartbeats
|
||||||
|
else {"task_executor_name": k}
|
||||||
|
)
|
||||||
|
return task_executor_list
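Worked example for the heartbeat branch above (values are illustrative): entries are sorted by 'now' descending, so the newest heartbeat per executor is flattened into a single row.

    data = {"task_executor_0": [{"done": 2, "now": 1700000000}, {"done": 3, "now": 1700000030}]}
    # _format_service_detail_table(data) ->
    # [{"task_executor_name": "task_executor_0", "done": 3, "now": 1700000030}]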
|
||||||
|
|
||||||
|
def _print_table_simple(self, data):
|
||||||
|
if not data:
|
||||||
|
print("No data to print")
|
||||||
|
return
|
||||||
|
if isinstance(data, dict):
|
||||||
|
# handle single row data
|
||||||
|
data = [data]
|
||||||
|
|
||||||
|
columns = list(set().union(*(d.keys() for d in data)))
|
||||||
|
columns.sort()
|
||||||
|
col_widths = {}
|
||||||
|
|
||||||
|
def get_string_width(text):
|
||||||
|
half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r"
|
||||||
|
width = 0
|
||||||
|
for char in text:
|
||||||
|
if char in half_width_chars:
|
||||||
|
width += 1
|
||||||
|
else:
|
||||||
|
width += 2
|
||||||
|
return width
|
||||||
|
|
||||||
|
for col in columns:
|
||||||
|
max_width = get_string_width(str(col))
|
||||||
|
for item in data:
|
||||||
|
value_len = get_string_width(str(item.get(col, "")))
|
||||||
|
if value_len > max_width:
|
||||||
|
max_width = value_len
|
||||||
|
col_widths[col] = max(2, max_width)
|
||||||
|
|
||||||
|
# Generate delimiter
|
||||||
|
separator = "+" + "+".join(["-" * (col_widths[col] + 2) for col in columns]) + "+"
|
||||||
|
|
||||||
|
# Print header
|
||||||
|
print(separator)
|
||||||
|
header = "|" + "|".join([f" {col:<{col_widths[col]}} " for col in columns]) + "|"
|
||||||
|
print(header)
|
||||||
|
print(separator)
|
||||||
|
|
||||||
|
# Print data
|
||||||
|
for item in data:
|
||||||
|
row = "|"
|
||||||
|
for col in columns:
|
||||||
|
value = str(item.get(col, ""))
|
||||||
|
if get_string_width(value) > col_widths[col]:
|
||||||
|
value = value[: col_widths[col] - 3] + "..."
|
||||||
|
row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |"
|
||||||
|
print(row)
|
||||||
|
|
||||||
|
print(separator)
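Worked example for the renderer above with two illustrative rows; columns are the sorted union of keys and widths adapt to the longest cell:

    rows = [{"email": "a@x.io", "role": "admin"}, {"email": "b@x.io", "role": "user"}]
    # self._print_table_simple(rows) prints:
    # +--------+-------+
    # | email  | role  |
    # +--------+-------+
    # | a@x.io | admin |
    # | b@x.io | user  |
    # +--------+-------+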
|
||||||
|
|
||||||
|
def run_interactive(self, args):
|
||||||
|
if self.verify_auth(args, single_command=False, auth=args["auth"]):
|
||||||
|
print(r"""
|
||||||
|
____ ___ ______________ ________ ____
|
||||||
|
/ __ \/ | / ____/ ____/ /___ _ __ / ____/ / / _/
|
||||||
|
/ /_/ / /| |/ / __/ /_ / / __ \ | /| / / / / / / / /
|
||||||
|
/ _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / /___/ /____/ /
|
||||||
|
/_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ \____/_____/___/
|
||||||
|
""")
|
||||||
|
self.cmdloop()
|
||||||
|
|
||||||
|
print("RAGFlow command line interface - Type '\\?' for help, '\\q' to quit")
|
||||||
|
|
||||||
|
def run_single_command(self, args):
|
||||||
|
if self.verify_auth(args, single_command=True, auth=args["auth"]):
|
||||||
|
command = args["command"]
|
||||||
|
result = self.parse_command(command)
|
||||||
|
self.execute_command(result)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_connection_args(self, args: List[str]) -> Dict[str, Any]:
|
||||||
|
parser = argparse.ArgumentParser(description="RAGFlow CLI Client", add_help=False)
|
||||||
|
parser.add_argument("-h", "--host", default="127.0.0.1", help="Admin or RAGFlow service host")
|
||||||
|
parser.add_argument("-p", "--port", type=int, default=9381, help="Admin or RAGFlow service port")
|
||||||
|
parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password")
|
||||||
|
parser.add_argument("-t", "--type", default="admin", type=str, help="CLI mode, admin or user")
|
||||||
|
parser.add_argument("-u", "--username", default=None,
|
||||||
|
help="Username (email). In admin mode defaults to admin@ragflow.io, in user mode required.")
|
||||||
|
parser.add_argument("command", nargs="?", help="Single command")
|
||||||
|
try:
|
||||||
|
parsed_args, remaining_args = parser.parse_known_args(args)
|
||||||
|
# Determine username based on mode
|
||||||
|
username = parsed_args.username
|
||||||
|
if parsed_args.type == "admin":
|
||||||
|
if username is None:
|
||||||
|
username = "admin@ragflow.io"
|
||||||
|
|
||||||
|
if remaining_args:
|
||||||
|
if remaining_args[0] == "command":
|
||||||
|
command_str = ' '.join(remaining_args[1:]) + ';'
|
||||||
|
auth = True
|
||||||
|
if remaining_args[1] == "register":
|
||||||
|
auth = False
|
||||||
|
else:
|
||||||
|
if username is None:
|
||||||
|
print("Error: username (-u) is required in user mode")
|
||||||
|
return {"error": "Username required"}
|
||||||
|
return {
|
||||||
|
"host": parsed_args.host,
|
||||||
|
"port": parsed_args.port,
|
||||||
|
"password": parsed_args.password,
|
||||||
|
"type": parsed_args.type,
|
||||||
|
"username": username,
|
||||||
|
"command": command_str,
|
||||||
|
"auth": auth
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {"error": "Invalid command"}
|
||||||
|
else:
|
||||||
|
auth = True
|
||||||
|
if username is None:
|
||||||
|
auth = False
|
||||||
|
return {
|
||||||
|
"host": parsed_args.host,
|
||||||
|
"port": parsed_args.port,
|
||||||
|
"type": parsed_args.type,
|
||||||
|
"username": username,
|
||||||
|
"auth": auth
|
||||||
|
}
|
||||||
|
except SystemExit:
|
||||||
|
return {"error": "Invalid connection arguments"}
|
||||||
|
|
||||||
|
def execute_command(self, parsed_command: Dict[str, Any]):
|
||||||
|
command_dict: dict
|
||||||
|
if isinstance(parsed_command, Tree):
|
||||||
|
command_dict = parsed_command.children[0]
|
||||||
|
else:
|
||||||
|
if parsed_command["type"] == "error":
|
||||||
|
print(f"Error: {parsed_command['message']}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
command_dict = parsed_command
|
||||||
|
|
||||||
|
# print(f"Parsed command: {command_dict}")
|
||||||
|
run_command(self.ragflow_client, command_dict)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
cli = RAGFlowCLI()
|
||||||
|
|
||||||
|
args = cli.parse_connection_args(sys.argv)
|
||||||
|
if "error" in args:
|
||||||
|
print("Error: Invalid connection arguments")
|
||||||
|
return
|
||||||
|
|
||||||
|
if "command" in args:
|
||||||
|
# single command mode
|
||||||
|
# for user mode, api key or password is ok
|
||||||
|
# for admin mode, only password
|
||||||
|
if "password" not in args:
|
||||||
|
print("Error: password is missing")
|
||||||
|
return
|
||||||
|
|
||||||
|
cli.run_single_command(args)
|
||||||
|
else:
|
||||||
|
cli.run_interactive(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
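As a quick orientation, here is a minimal usage sketch of the client above. It assumes the file is importable as admin_client (the module path is not shown in this excerpt) and mirrors how main() feeds sys.argv to parse_connection_args.

# Hypothetical driver, equivalent to:
#   python admin_client.py -h 127.0.0.1 -p 9381 -w admin command list_users
# The import path is an assumption; only the RAGFlowCLI class name appears above.
from admin_client import RAGFlowCLI

cli = RAGFlowCLI()
args = cli.parse_connection_args(
    ["admin_client.py", "-h", "127.0.0.1", "-p", "9381", "-w", "admin", "command", "list_users"]
)
if "error" in args:
    raise SystemExit(args["error"])
if "command" in args:
    cli.run_single_command(args)
else:
    cli.run_interactive(args)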
admin/client/ragflow_client.py (new file, 1492 lines): diff suppressed because it is too large.

admin/client/user.py (new file, 65 lines)
@@ -0,0 +1,65 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from http_client import HttpClient


class AuthException(Exception):
    def __init__(self, message, code=401):
        super().__init__(message)
        self.code = code
        self.message = message


def encrypt_password(password_plain: str) -> str:
    try:
        from api.utils.crypt import crypt
    except Exception as exc:
        raise AuthException(
            "Password encryption unavailable; install pycryptodomex (uv sync --python 3.12 --group test)."
        ) from exc
    return crypt(password_plain)


def register_user(client: HttpClient, email: str, nickname: str, password: str) -> None:
    password_enc = encrypt_password(password)
    payload = {"email": email, "nickname": nickname, "password": password_enc}
    res = client.request_json("POST", "/user/register", use_api_base=False, auth_kind=None, json_body=payload)
    if res.get("code") == 0:
        return
    msg = res.get("message", "")
    if "has already registered" in msg:
        return
    raise AuthException(f"Register failed: {msg}")


def login_user(client: HttpClient, server_type: str, email: str, password: str) -> str:
    password_enc = encrypt_password(password)
    payload = {"email": email, "password": password_enc}
    if server_type == "admin":
        response = client.request("POST", "/admin/login", use_api_base=True, auth_kind=None, json_body=payload)
    else:
        response = client.request("POST", "/user/login", use_api_base=False, auth_kind=None, json_body=payload)
    try:
        res = response.json()
    except Exception as exc:
        raise AuthException(f"Login failed: invalid JSON response ({exc})") from exc
    if res.get("code") != 0:
        raise AuthException(f"Login failed: {res.get('message')}")
    token = response.headers.get("Authorization")
    if not token:
        raise AuthException("Login failed: missing Authorization header")
    return token
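A short usage sketch for these helpers follows. The HttpClient constructor arguments are an assumption; only its request and request_json methods appear in this excerpt.

# Sketch only: register a user and log in through the helpers above.
from http_client import HttpClient
from user import AuthException, login_user, register_user

client = HttpClient(host="127.0.0.1", port=9380)  # constructor signature assumed, not shown here
try:
    register_user(client, "someone@example.com", "someone", "secret")
    token = login_user(client, "user", "someone@example.com", "secret")
    print("Authorization:", token)
except AuthException as exc:
    print(f"Auth error ({exc.code}): {exc.message}")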
admin/client/uv.lock (generated, 2 changed lines)
@@ -196,7 +196,7 @@ wheels = [

 [[package]]
 name = "ragflow-cli"
-version = "0.22.1"
+version = "0.23.1"
 source = { virtual = "." }
 dependencies = [
     { name = "beartype" },
@@ -14,10 +14,12 @@
 # limitations under the License.
 #
 
+import time
+start_ts = time.time()
+
 import os
 import signal
 import logging
-import time
 import threading
 import traceback
 import faulthandler
@@ -66,7 +68,7 @@ if __name__ == '__main__':
     SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)

     try:
-        logging.info("RAGFlow Admin service start...")
+        logging.info(f"RAGFlow admin is ready after {time.time() - start_ts}s initialization.")
         run_simple(
             hostname="0.0.0.0",
             port=9381,
|
|||||||
#
|
#
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask import Blueprint, request
|
from common.time_utils import current_timestamp, datetime_format
|
||||||
|
from datetime import datetime
|
||||||
|
from flask import Blueprint, Response, request
|
||||||
from flask_login import current_user, login_required, logout_user
|
from flask_login import current_user, login_required, logout_user
|
||||||
|
|
||||||
from auth import login_verify, login_admin, check_admin_auth
|
from auth import login_verify, login_admin, check_admin_auth
|
||||||
from responses import success_response, error_response
|
from responses import success_response, error_response
|
||||||
from services import UserMgr, ServiceMgr, UserServiceMgr
|
from services import UserMgr, ServiceMgr, UserServiceMgr, SettingsMgr, ConfigMgr, EnvironmentsMgr
|
||||||
from roles import RoleMgr
|
from roles import RoleMgr
|
||||||
from api.common.exceptions import AdminException
|
from api.common.exceptions import AdminException
|
||||||
from common.versions import get_ragflow_version
|
from common.versions import get_ragflow_version
|
||||||
|
from api.utils.api_utils import generate_confirmation_token
|
||||||
|
|
||||||
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
admin_bp = Blueprint("admin", __name__, url_prefix="/api/v1/admin")
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/login', methods=['POST'])
|
@admin_bp.route("/ping", methods=["GET"])
|
||||||
|
def ping():
|
||||||
|
return success_response("PONG")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/login", methods=["POST"])
|
||||||
def login():
|
def login():
|
||||||
if not request.json:
|
if not request.json:
|
||||||
return error_response('Authorize admin failed.' ,400)
|
return error_response("Authorize admin failed.", 400)
|
||||||
try:
|
try:
|
||||||
email = request.json.get("email", "")
|
email = request.json.get("email", "")
|
||||||
password = request.json.get("password", "")
|
password = request.json.get("password", "")
|
||||||
@ -41,7 +50,7 @@ def login():
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/logout', methods=['GET'])
|
@admin_bp.route("/logout", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
def logout():
|
def logout():
|
||||||
try:
|
try:
|
||||||
@ -53,7 +62,7 @@ def logout():
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/auth', methods=['GET'])
|
@admin_bp.route("/auth", methods=["GET"])
|
||||||
@login_verify
|
@login_verify
|
||||||
def auth_admin():
|
def auth_admin():
|
||||||
try:
|
try:
|
||||||
@ -62,7 +71,7 @@ def auth_admin():
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users', methods=['GET'])
|
@admin_bp.route("/users", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def list_users():
|
def list_users():
|
||||||
@ -73,18 +82,18 @@ def list_users():
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users', methods=['POST'])
|
@admin_bp.route("/users", methods=["POST"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def create_user():
|
def create_user():
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'username' not in data or 'password' not in data:
|
if not data or "username" not in data or "password" not in data:
|
||||||
return error_response("Username and password are required", 400)
|
return error_response("Username and password are required", 400)
|
||||||
|
|
||||||
username = data['username']
|
username = data["username"]
|
||||||
password = data['password']
|
password = data["password"]
|
||||||
role = data.get('role', 'user')
|
role = data.get("role", "user")
|
||||||
|
|
||||||
res = UserMgr.create_user(username, password, role)
|
res = UserMgr.create_user(username, password, role)
|
||||||
if res["success"]:
|
if res["success"]:
|
||||||
@ -100,7 +109,7 @@ def create_user():
|
|||||||
return error_response(str(e))
|
return error_response(str(e))
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>', methods=['DELETE'])
|
@admin_bp.route("/users/<username>", methods=["DELETE"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def delete_user(username):
|
def delete_user(username):
|
||||||
@ -117,16 +126,16 @@ def delete_user(username):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>/password', methods=['PUT'])
|
@admin_bp.route("/users/<username>/password", methods=["PUT"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def change_password(username):
|
def change_password(username):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'new_password' not in data:
|
if not data or "new_password" not in data:
|
||||||
return error_response("New password is required", 400)
|
return error_response("New password is required", 400)
|
||||||
|
|
||||||
new_password = data['new_password']
|
new_password = data["new_password"]
|
||||||
msg = UserMgr.update_user_password(username, new_password)
|
msg = UserMgr.update_user_password(username, new_password)
|
||||||
return success_response(None, msg)
|
return success_response(None, msg)
|
||||||
|
|
||||||
@ -136,15 +145,15 @@ def change_password(username):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>/activate', methods=['PUT'])
|
@admin_bp.route("/users/<username>/activate", methods=["PUT"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def alter_user_activate_status(username):
|
def alter_user_activate_status(username):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'activate_status' not in data:
|
if not data or "activate_status" not in data:
|
||||||
return error_response("Activation status is required", 400)
|
return error_response("Activation status is required", 400)
|
||||||
activate_status = data['activate_status']
|
activate_status = data["activate_status"]
|
||||||
msg = UserMgr.update_user_activate_status(username, activate_status)
|
msg = UserMgr.update_user_activate_status(username, activate_status)
|
||||||
return success_response(None, msg)
|
return success_response(None, msg)
|
||||||
except AdminException as e:
|
except AdminException as e:
|
||||||
@ -153,7 +162,39 @@ def alter_user_activate_status(username):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>', methods=['GET'])
|
@admin_bp.route("/users/<username>/admin", methods=["PUT"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def grant_admin(username):
|
||||||
|
try:
|
||||||
|
if current_user.email == username:
|
||||||
|
return error_response(f"can't grant current user: {username}", 409)
|
||||||
|
msg = UserMgr.grant_admin(username)
|
||||||
|
return success_response(None, msg)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/users/<username>/admin", methods=["DELETE"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def revoke_admin(username):
|
||||||
|
try:
|
||||||
|
if current_user.email == username:
|
||||||
|
return error_response(f"can't grant current user: {username}", 409)
|
||||||
|
msg = UserMgr.revoke_admin(username)
|
||||||
|
return success_response(None, msg)
|
||||||
|
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/users/<username>", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_user_details(username):
|
def get_user_details(username):
|
||||||
@ -167,7 +208,7 @@ def get_user_details(username):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>/datasets', methods=['GET'])
|
@admin_bp.route("/users/<username>/datasets", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_user_datasets(username):
|
def get_user_datasets(username):
|
||||||
@ -181,7 +222,7 @@ def get_user_datasets(username):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<username>/agents', methods=['GET'])
|
@admin_bp.route("/users/<username>/agents", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_user_agents(username):
|
def get_user_agents(username):
|
||||||
@ -195,7 +236,7 @@ def get_user_agents(username):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/services', methods=['GET'])
|
@admin_bp.route("/services", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_services():
|
def get_services():
|
||||||
@ -206,7 +247,7 @@ def get_services():
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/service_types/<service_type>', methods=['GET'])
|
@admin_bp.route("/service_types/<service_type>", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_services_by_type(service_type_str):
|
def get_services_by_type(service_type_str):
|
||||||
@ -217,7 +258,7 @@ def get_services_by_type(service_type_str):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/services/<service_id>', methods=['GET'])
|
@admin_bp.route("/services/<service_id>", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_service(service_id):
|
def get_service(service_id):
|
||||||
@ -228,7 +269,7 @@ def get_service(service_id):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/services/<service_id>', methods=['DELETE'])
|
@admin_bp.route("/services/<service_id>", methods=["DELETE"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def shutdown_service(service_id):
|
def shutdown_service(service_id):
|
||||||
@ -239,7 +280,7 @@ def shutdown_service(service_id):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/services/<service_id>', methods=['PUT'])
|
@admin_bp.route("/services/<service_id>", methods=["PUT"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def restart_service(service_id):
|
def restart_service(service_id):
|
||||||
@ -250,38 +291,38 @@ def restart_service(service_id):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/roles', methods=['POST'])
|
@admin_bp.route("/roles", methods=["POST"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def create_role():
|
def create_role():
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'role_name' not in data:
|
if not data or "role_name" not in data:
|
||||||
return error_response("Role name is required", 400)
|
return error_response("Role name is required", 400)
|
||||||
role_name: str = data['role_name']
|
role_name: str = data["role_name"]
|
||||||
description: str = data['description']
|
description: str = data["description"]
|
||||||
res = RoleMgr.create_role(role_name, description)
|
res = RoleMgr.create_role(role_name, description)
|
||||||
return success_response(res)
|
return success_response(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/roles/<role_name>', methods=['PUT'])
|
@admin_bp.route("/roles/<role_name>", methods=["PUT"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def update_role(role_name: str):
|
def update_role(role_name: str):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'description' not in data:
|
if not data or "description" not in data:
|
||||||
return error_response("Role description is required", 400)
|
return error_response("Role description is required", 400)
|
||||||
description: str = data['description']
|
description: str = data["description"]
|
||||||
res = RoleMgr.update_role_description(role_name, description)
|
res = RoleMgr.update_role_description(role_name, description)
|
||||||
return success_response(res)
|
return success_response(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/roles/<role_name>', methods=['DELETE'])
|
@admin_bp.route("/roles/<role_name>", methods=["DELETE"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def delete_role(role_name: str):
|
def delete_role(role_name: str):
|
||||||
@ -292,7 +333,7 @@ def delete_role(role_name: str):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/roles', methods=['GET'])
|
@admin_bp.route("/roles", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def list_roles():
|
def list_roles():
|
||||||
@ -303,7 +344,7 @@ def list_roles():
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/roles/<role_name>/permission', methods=['GET'])
|
@admin_bp.route("/roles/<role_name>/permission", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_role_permission(role_name: str):
|
def get_role_permission(role_name: str):
|
||||||
@ -314,54 +355,54 @@ def get_role_permission(role_name: str):
|
|||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/roles/<role_name>/permission', methods=['POST'])
|
@admin_bp.route("/roles/<role_name>/permission", methods=["POST"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def grant_role_permission(role_name: str):
|
def grant_role_permission(role_name: str):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'actions' not in data or 'resource' not in data:
|
if not data or "actions" not in data or "resource" not in data:
|
||||||
return error_response("Permission is required", 400)
|
return error_response("Permission is required", 400)
|
||||||
actions: list = data['actions']
|
actions: list = data["actions"]
|
||||||
resource: str = data['resource']
|
resource: str = data["resource"]
|
||||||
res = RoleMgr.grant_role_permission(role_name, actions, resource)
|
res = RoleMgr.grant_role_permission(role_name, actions, resource)
|
||||||
return success_response(res)
|
return success_response(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/roles/<role_name>/permission', methods=['DELETE'])
|
@admin_bp.route("/roles/<role_name>/permission", methods=["DELETE"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def revoke_role_permission(role_name: str):
|
def revoke_role_permission(role_name: str):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'actions' not in data or 'resource' not in data:
|
if not data or "actions" not in data or "resource" not in data:
|
||||||
return error_response("Permission is required", 400)
|
return error_response("Permission is required", 400)
|
||||||
actions: list = data['actions']
|
actions: list = data["actions"]
|
||||||
resource: str = data['resource']
|
resource: str = data["resource"]
|
||||||
res = RoleMgr.revoke_role_permission(role_name, actions, resource)
|
res = RoleMgr.revoke_role_permission(role_name, actions, resource)
|
||||||
return success_response(res)
|
return success_response(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<user_name>/role', methods=['PUT'])
|
@admin_bp.route("/users/<user_name>/role", methods=["PUT"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def update_user_role(user_name: str):
|
def update_user_role(user_name: str):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data or 'role_name' not in data:
|
if not data or "role_name" not in data:
|
||||||
return error_response("Role name is required", 400)
|
return error_response("Role name is required", 400)
|
||||||
role_name: str = data['role_name']
|
role_name: str = data["role_name"]
|
||||||
res = RoleMgr.update_user_role(user_name, role_name)
|
res = RoleMgr.update_user_role(user_name, role_name)
|
||||||
return success_response(res)
|
return success_response(res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
@admin_bp.route('/users/<user_name>/permission', methods=['GET'])
|
@admin_bp.route("/users/<user_name>/permission", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def get_user_permission(user_name: str):
|
def get_user_permission(user_name: str):
|
||||||
@ -371,7 +412,140 @@ def get_user_permission(user_name: str):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return error_response(str(e), 500)
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
@admin_bp.route('/version', methods=['GET'])
|
|
||||||
|
@admin_bp.route("/variables", methods=["PUT"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def set_variable():
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
if not data and "var_name" not in data:
|
||||||
|
return error_response("Var name is required", 400)
|
||||||
|
|
||||||
|
if "var_value" not in data:
|
||||||
|
return error_response("Var value is required", 400)
|
||||||
|
var_name: str = data["var_name"]
|
||||||
|
var_value: str = data["var_value"]
|
||||||
|
|
||||||
|
SettingsMgr.update_by_name(var_name, var_value)
|
||||||
|
return success_response(None, "Set variable successfully")
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(str(e), 400)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/variables", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def get_variable():
|
||||||
|
try:
|
||||||
|
if request.content_length is None or request.content_length == 0:
|
||||||
|
# list variables
|
||||||
|
res = list(SettingsMgr.get_all())
|
||||||
|
return success_response(res)
|
||||||
|
|
||||||
|
# get var
|
||||||
|
data = request.get_json()
|
||||||
|
if not data and "var_name" not in data:
|
||||||
|
return error_response("Var name is required", 400)
|
||||||
|
var_name: str = data["var_name"]
|
||||||
|
res = SettingsMgr.get_by_name(var_name)
|
||||||
|
return success_response(res)
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(str(e), 400)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/configs", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def get_config():
|
||||||
|
try:
|
||||||
|
res = list(ConfigMgr.get_all())
|
||||||
|
return success_response(res)
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(str(e), 400)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/environments", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def get_environments():
|
||||||
|
try:
|
||||||
|
res = list(EnvironmentsMgr.get_all())
|
||||||
|
return success_response(res)
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(str(e), 400)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/users/<username>/keys", methods=["POST"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def generate_user_api_key(username: str) -> tuple[Response, int]:
|
||||||
|
try:
|
||||||
|
user_details: list[dict[str, Any]] = UserMgr.get_user_details(username)
|
||||||
|
if not user_details:
|
||||||
|
return error_response("User not found!", 404)
|
||||||
|
tenants: list[dict[str, Any]] = UserServiceMgr.get_user_tenants(username)
|
||||||
|
if not tenants:
|
||||||
|
return error_response("Tenant not found!", 404)
|
||||||
|
tenant_id: str = tenants[0]["tenant_id"]
|
||||||
|
key: str = generate_confirmation_token()
|
||||||
|
obj: dict[str, Any] = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"token": key,
|
||||||
|
"beta": generate_confirmation_token().replace("ragflow-", "")[:32],
|
||||||
|
"create_time": current_timestamp(),
|
||||||
|
"create_date": datetime_format(datetime.now()),
|
||||||
|
"update_time": None,
|
||||||
|
"update_date": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not UserMgr.save_api_key(obj):
|
||||||
|
return error_response("Failed to generate API key!", 500)
|
||||||
|
return success_response(obj, "API key generated successfully")
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/users/<username>/keys", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def get_user_api_keys(username: str) -> tuple[Response, int]:
|
||||||
|
try:
|
||||||
|
api_keys: list[dict[str, Any]] = UserMgr.get_user_api_key(username)
|
||||||
|
return success_response(api_keys, "Get user API keys")
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/users/<username>/keys/<key>", methods=["DELETE"])
|
||||||
|
@login_required
|
||||||
|
@check_admin_auth
|
||||||
|
def delete_user_api_key(username: str, key: str) -> tuple[Response, int]:
|
||||||
|
try:
|
||||||
|
deleted = UserMgr.delete_api_key(username, key)
|
||||||
|
if deleted:
|
||||||
|
return success_response(None, "API key deleted successfully")
|
||||||
|
else:
|
||||||
|
return error_response("API key not found or could not be deleted", 404)
|
||||||
|
except AdminException as e:
|
||||||
|
return error_response(e.message, e.code)
|
||||||
|
except Exception as e:
|
||||||
|
return error_response(str(e), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_bp.route("/version", methods=["GET"])
|
||||||
@login_required
|
@login_required
|
||||||
@check_admin_auth
|
@check_admin_auth
|
||||||
def show_version():
|
def show_version():
|
||||||
|
|||||||
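For reference, a sketch of exercising the new per-user API key routes with requests. It assumes an authenticated admin session (cookies or headers obtained from /api/v1/admin/login) and that success_response wraps its payload under a data field; neither detail is shown in this excerpt.

# Sketch only: drive the new /users/<username>/keys endpoints.
import requests

BASE = "http://127.0.0.1:9381/api/v1/admin"
session = requests.Session()  # assume admin login has already populated auth cookies/headers

created = session.post(f"{BASE}/users/someone@example.com/keys").json()
print(created)

listed = session.get(f"{BASE}/users/someone@example.com/keys").json()
for key in listed.get("data", []):  # the "data" envelope is an assumption about success_response
    session.delete(f"{BASE}/users/someone@example.com/keys/{key['token']}")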
@@ -13,15 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 
+import os
 import logging
 import re
+from typing import Any
+
 from werkzeug.security import check_password_hash
 from common.constants import ActiveEnum
 from api.db.services import UserService
 from api.db.joint_services.user_account_service import create_new_user, delete_user_data
 from api.db.services.canvas_service import UserCanvasService
-from api.db.services.user_service import TenantService
+from api.db.services.user_service import TenantService, UserTenantService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.system_settings_service import SystemSettingsService
+from api.db.services.api_service import APITokenService
+from api.db.db_models import APIToken
 from api.utils.crypt import decrypt
 from api.utils import health_utils
@@ -35,13 +42,15 @@ class UserMgr:
         users = UserService.get_all_users()
         result = []
         for user in users:
-            result.append({
-                'email': user.email,
-                'nickname': user.nickname,
-                'create_date': user.create_date,
-                'is_active': user.is_active,
-                'is_superuser': user.is_superuser,
-            })
+            result.append(
+                {
+                    "email": user.email,
+                    "nickname": user.nickname,
+                    "create_date": user.create_date,
+                    "is_active": user.is_active,
+                    "is_superuser": user.is_superuser,
+                }
+            )
         return result
 
     @staticmethod
@@ -50,19 +59,21 @@ class UserMgr:
         users = UserService.query_user_by_email(username)
         result = []
         for user in users:
-            result.append({
-                'avatar': user.avatar,
-                'email': user.email,
-                'language': user.language,
-                'last_login_time': user.last_login_time,
-                'is_active': user.is_active,
-                'is_anonymous': user.is_anonymous,
-                'login_channel': user.login_channel,
-                'status': user.status,
-                'is_superuser': user.is_superuser,
-                'create_date': user.create_date,
-                'update_date': user.update_date
-            })
+            result.append(
+                {
+                    "avatar": user.avatar,
+                    "email": user.email,
+                    "language": user.language,
+                    "last_login_time": user.last_login_time,
+                    "is_active": user.is_active,
+                    "is_anonymous": user.is_anonymous,
+                    "login_channel": user.login_channel,
+                    "status": user.status,
+                    "is_superuser": user.is_superuser,
+                    "create_date": user.create_date,
+                    "update_date": user.update_date,
+                }
+            )
         return result
 
     @staticmethod
@@ -124,8 +135,8 @@ class UserMgr:
         # format activate_status before handle
         _activate_status = activate_status.lower()
         target_status = {
-            'on': ActiveEnum.ACTIVE.value,
-            'off': ActiveEnum.INACTIVE.value,
+            "on": ActiveEnum.ACTIVE.value,
+            "off": ActiveEnum.INACTIVE.value,
         }.get(_activate_status)
         if not target_status:
             raise AdminException(f"Invalid activate_status: {activate_status}")
@@ -135,9 +146,84 @@ class UserMgr:
         UserService.update_user(usr.id, {"is_active": target_status})
         return f"Turn {_activate_status} user activate status successfully!"
 
+    @staticmethod
+    def get_user_api_key(username: str) -> list[dict[str, Any]]:
+        # use email to find user. check exist and unique.
+        user_list: list[Any] = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"More than one user with username '{username}' found!")
+
+        usr: Any = user_list[0]
+        # tenant_id is typically the same as user_id for the owner tenant
+        tenant_id: str = usr.id
+
+        # Query all API keys for this tenant
+        api_keys: Any = APITokenService.query(tenant_id=tenant_id)
+
+        result: list[dict[str, Any]] = []
+        for key in api_keys:
+            result.append(key.to_dict())
+
+        return result
+
+    @staticmethod
+    def save_api_key(api_key: dict[str, Any]) -> bool:
+        return APITokenService.save(**api_key)
+
+    @staticmethod
+    def delete_api_key(username: str, key: str) -> bool:
+        # use email to find user. check exist and unique.
+        user_list: list[Any] = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+
+        usr: Any = user_list[0]
+        # tenant_id is typically the same as user_id for the owner tenant
+        tenant_id: str = usr.id
+
+        # Delete the API key
+        deleted_count: int = APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == key])
+        return deleted_count > 0
+
+    @staticmethod
+    def grant_admin(username: str):
+        # use email to find user. check exist and unique.
+        user_list = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+
+        # check activate status different from new
+        usr = user_list[0]
+        if usr.is_superuser:
+            return f"{usr} is already superuser!"
+        # update is_active
+        UserService.update_user(usr.id, {"is_superuser": True})
+        return "Grant successfully!"
+
+    @staticmethod
+    def revoke_admin(username: str):
+        # use email to find user. check exist and unique.
+        user_list = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+        # check activate status different from new
+        usr = user_list[0]
+        if not usr.is_superuser:
+            return f"{usr} isn't superuser, yet!"
+        # update is_active
+        UserService.update_user(usr.id, {"is_superuser": False})
+        return "Revoke successfully!"
+
 
 class UserServiceMgr:
 
     @staticmethod
     def get_user_datasets(username):
         # use email to find user.
@@ -167,35 +253,43 @@ class UserServiceMgr:
         tenant_ids = [m["tenant_id"] for m in tenants]
         # filter permitted agents and owned agents
         res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
-        return [{
-            'title': r['title'],
-            'permission': r['permission'],
-            'canvas_category': r['canvas_category'].split('_')[0],
-            'avatar': r['avatar']
-        } for r in res]
+        return [{"title": r["title"], "permission": r["permission"], "canvas_category": r["canvas_category"].split("_")[0], "avatar": r["avatar"]} for r in res]
+
+    @staticmethod
+    def get_user_tenants(email: str) -> list[dict[str, Any]]:
+        users: list[Any] = UserService.query_user_by_email(email)
+        if not users:
+            raise UserNotFoundError(email)
+        user: Any = users[0]
+
+        tenants: list[dict[str, Any]] = UserTenantService.get_tenants_by_user_id(user.id)
+        return tenants
 
 
 class ServiceMgr:
 
     @staticmethod
     def get_all_services():
+        doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
         result = []
         configs = SERVICE_CONFIGS.configs
        for service_id, config in enumerate(configs):
             config_dict = config.to_dict()
+            if config_dict["service_type"] == "retrieval":
+                if config_dict["extra"]["retrieval_type"] != doc_engine:
+                    continue
             try:
                 service_detail = ServiceMgr.get_service_details(service_id)
                 if "status" in service_detail:
-                    config_dict['status'] = service_detail['status']
+                    config_dict["status"] = service_detail["status"]
                 else:
-                    config_dict['status'] = 'timeout'
+                    config_dict["status"] = "timeout"
             except Exception as e:
                 logging.warning(f"Can't get service details, error: {e}")
-                config_dict['status'] = 'timeout'
-            if not config_dict['host']:
-                config_dict['host'] = '-'
-            if not config_dict['port']:
-                config_dict['port'] = '-'
+                config_dict["status"] = "timeout"
+            if not config_dict["host"]:
+                config_dict["host"] = "-"
+            if not config_dict["port"]:
+                config_dict["port"] = "-"
             result.append(config_dict)
         return result
 
@@ -211,11 +305,18 @@ class ServiceMgr:
             raise AdminException(f"invalid service_index: {service_idx}")
 
         service_config = configs[service_idx]
-        service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
 
-        detail_func = getattr(health_utils, service_info.get('detail_func_name'))
+        # exclude retrieval service if retrieval_type is not matched
+        doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
+        if service_config.service_type == "retrieval":
+            if service_config.retrieval_type != doc_engine:
+                raise AdminException(f"invalid service_index: {service_idx}")
+
+        service_info = {"name": service_config.name, "detail_func_name": service_config.detail_func_name}
+
+        detail_func = getattr(health_utils, service_info.get("detail_func_name"))
         res = detail_func()
-        res.update({'service_name': service_info.get('name')})
+        res.update({"service_name": service_info.get("name")})
         return res
 
     @staticmethod
@@ -225,3 +326,84 @@ class ServiceMgr:
     @staticmethod
     def restart_service(service_id: int):
         raise AdminException("restart_service: not implemented")
+
+
+class SettingsMgr:
+    @staticmethod
+    def get_all():
+        settings = SystemSettingsService.get_all()
+        result = []
+        for setting in settings:
+            result.append(
+                {
+                    "name": setting.name,
+                    "source": setting.source,
+                    "data_type": setting.data_type,
+                    "value": setting.value,
+                }
+            )
+        return result
+
+    @staticmethod
+    def get_by_name(name: str):
+        settings = SystemSettingsService.get_by_name(name)
+        if len(settings) == 0:
+            raise AdminException(f"Can't get setting: {name}")
+        result = []
+        for setting in settings:
+            result.append(
+                {
+                    "name": setting.name,
+                    "source": setting.source,
+                    "data_type": setting.data_type,
+                    "value": setting.value,
+                }
+            )
+        return result
+
+    @staticmethod
+    def update_by_name(name: str, value: str):
+        settings = SystemSettingsService.get_by_name(name)
+        if len(settings) == 1:
+            setting = settings[0]
+            setting.value = value
+            setting_dict = setting.to_dict()
+            SystemSettingsService.update_by_name(name, setting_dict)
+        elif len(settings) > 1:
+            raise AdminException(f"Can't update more than 1 setting: {name}")
+        else:
+            raise AdminException(f"No setting: {name}")
+
+
+class ConfigMgr:
+    @staticmethod
+    def get_all():
+        result = []
+        configs = SERVICE_CONFIGS.configs
+        for config in configs:
+            config_dict = config.to_dict()
+            result.append(config_dict)
+        return result
+
+
+class EnvironmentsMgr:
+    @staticmethod
+    def get_all():
+        result = []
+
+        env_kv = {"env": "DOC_ENGINE", "value": os.getenv("DOC_ENGINE")}
+        result.append(env_kv)
+
+        env_kv = {"env": "DEFAULT_SUPERUSER_EMAIL", "value": os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")}
+        result.append(env_kv)
+
+        env_kv = {"env": "DB_TYPE", "value": os.getenv("DB_TYPE", "mysql")}
+        result.append(env_kv)
+
+        env_kv = {"env": "DEVICE", "value": os.getenv("DEVICE", "cpu")}
+        result.append(env_kv)
+
+        env_kv = {"env": "STORAGE_IMPL", "value": os.getenv("STORAGE_IMPL", "MINIO")}
+        result.append(env_kv)
+
+        return result
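EnvironmentsMgr above probes a fixed set of environment variables. Purely as an illustration of the shape it produces (actual values depend on the deployment; DOC_ENGINE has no default in this code, so it can be None):

# Illustrative output of EnvironmentsMgr.get_all() on a host using the documented defaults.
example = [
    {"env": "DOC_ENGINE", "value": None},
    {"env": "DEFAULT_SUPERUSER_EMAIL", "value": "admin@ragflow.io"},
    {"env": "DB_TYPE", "value": "mysql"},
    {"env": "DEVICE", "value": "cpu"},
    {"env": "STORAGE_IMPL", "value": "MINIO"},
]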
@@ -160,7 +160,7 @@ class Graph:
         return self._tenant_id
 
     def get_value_with_variable(self,value: str) -> Any:
-        pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
+        pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
         out_parts = []
         last = 0
@@ -278,7 +278,7 @@ class Graph:
 
 class Canvas(Graph):
 
-    def __init__(self, dsl: str, tenant_id=None, task_id=None):
+    def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None):
         self.globals = {
             "sys.query": "",
             "sys.user_id": tenant_id,
@@ -287,6 +287,7 @@ class Canvas(Graph):
         }
         self.variables = {}
         super().__init__(dsl, tenant_id, task_id)
+        self._id = canvas_id
 
     def load(self):
         super().load()
@@ -368,8 +369,13 @@ class Canvas(Graph):
 
         if kwargs.get("webhook_payload"):
             for k, cpn in self.components.items():
-                if self.components[k]["obj"].component_name.lower() == "webhook":
-                    for kk, vv in kwargs["webhook_payload"].items():
+                if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
+                    payload = kwargs.get("webhook_payload", {})
+                    if "input" in payload:
+                        self.components[k]["obj"].set_input_value("request", payload["input"])
+                    for kk, vv in payload.items():
+                        if kk == "input":
+                            continue
                         self.components[k]["obj"].set_output(kk, vv)
 
         for k in kwargs.keys():
@@ -535,6 +541,8 @@ class Canvas(Graph):
             cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
 
             message_end = {}
+            if cpn_obj.get_param("status"):
+                message_end["status"] = cpn_obj.get_param("status")
             if isinstance(cpn_obj.output("attachment"), dict):
                 message_end["attachment"] = cpn_obj.output("attachment")
             if cite:
@@ -714,6 +722,9 @@ class Canvas(Graph):
     def get_mode(self):
         return self.components["begin"]["obj"]._param.mode
 
+    def get_sys_query(self):
+        return self.globals.get("sys.query", "")
+
     def set_global_param(self, **kwargs):
         self.globals.update(kwargs)
 
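The webhook handling above now targets a Begin component in Webhook mode: the payload's input entry is routed through set_input_value("request", ...), and every other key becomes an output of that Begin component. A sketch of the payload shape, assuming the surrounding run entry point accepts a webhook_payload keyword as this hunk suggests:

# Hypothetical payload: "input" feeds the Begin component's request input,
# the remaining keys become outputs of that Begin component.
payload = {
    "input": {"question": "What is RAGFlow?"},
    "trace_id": "abc123",
}
# canvas.run(webhook_payload=payload)  # exact run() signature is not shown in this excerpt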
@@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from common.connection_utils import timeout
-from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
-    citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
+from rag.prompts.generator import next_step_async, COMPLETE_TASK, \
+    citation_prompt, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
 from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM
@@ -84,9 +84,11 @@ class Agent(LLM, ToolBase):
     def __init__(self, canvas, id, param: LLMParam):
         LLM.__init__(self, canvas, id, param)
         self.tools = {}
-        for cpn in self._param.tools:
+        for idx, cpn in enumerate(self._param.tools):
             cpn = self._load_tool_obj(cpn)
-            self.tools[cpn.get_meta()["function"]["name"]] = cpn
+            original_name = cpn.get_meta()["function"]["name"]
+            indexed_name = f"{original_name}_{idx}"
+            self.tools[indexed_name] = cpn
 
         self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                   max_retries=self._param.max_retries,
@@ -94,7 +96,12 @@ class Agent(LLM, ToolBase):
                                   max_rounds=self._param.max_rounds,
                                   verbose_tool_use=True
                                   )
-        self.tool_meta = [v.get_meta() for _,v in self.tools.items()]
+        self.tool_meta = []
+        for indexed_name, tool_obj in self.tools.items():
+            original_meta = tool_obj.get_meta()
+            indexed_meta = deepcopy(original_meta)
+            indexed_meta["function"]["name"] = indexed_name
+            self.tool_meta.append(indexed_meta)
+
         for mcp in self._param.mcp:
             _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
@@ -108,7 +115,8 @@ class Agent(LLM, ToolBase):
 
     def _load_tool_obj(self, cpn: dict) -> object:
         from agent.component import component_class
-        param = component_class(cpn["component_name"] + "Param")()
+        tool_name = cpn["component_name"]
+        param = component_class(tool_name + "Param")()
         param.update(cpn["params"])
         try:
             param.check()
@@ -202,7 +210,7 @@ class Agent(LLM, ToolBase):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
+        async for delta_ans, _tk in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
             if self.check_if_canceled("Agent processing"):
                 return
             ans += delta_ans
@@ -246,7 +254,7 @@ class Agent(LLM, ToolBase):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
         use_tools = []
-        async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
+        async for delta_ans, _ in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt):
             if self.check_if_canceled("Agent streaming"):
                 return
 
@@ -264,7 +272,7 @@ class Agent(LLM, ToolBase):
         if use_tools:
             self.set_output("use_tools", use_tools)
 
-    async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+    async def _react_with_tools_streamly_async_simple(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
         token_count = 0
         tool_metas = self.tool_meta
         hist = deepcopy(history)
@@ -276,6 +284,24 @@ class Agent(LLM, ToolBase):
         else:
             user_request = history[-1]["content"]
 
+        def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str:
+            """Build a minimal task_desc by concatenating prompt, query, and tool schemas."""
+            user_defined_prompt = user_defined_prompt or {}
+
+            task_desc = (
+                "### Agent Prompt\n"
+                f"{prompt}\n\n"
+                "### User Request\n"
+                f"{user_request}\n\n"
+            )
+
+            if user_defined_prompt:
+                udp_json = json.dumps(user_defined_prompt, ensure_ascii=False, indent=2)
+                task_desc += "\n### User Defined Prompts\n" + udp_json + "\n"
+
+            return task_desc
+
+
         async def use_tool_async(name, args):
             nonlocal hist, use_tools, last_calling
             logging.info(f"{last_calling=} == {name=}")
@@ -286,9 +312,6 @@ class Agent(LLM, ToolBase):
                 "arguments": args,
                 "results": tool_response
             })
-            # self.callback("add_memory", {}, "...")
-            #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
-
             return name, tool_response
 
         async def complete():
@@ -326,6 +349,21 @@ class Agent(LLM, ToolBase):
 
             self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
 
+        def build_observation(tool_call_res: list[tuple]) -> str:
+            """
+            Build a Observation from tool call results.
+            No LLM involved.
+            """
+            if not tool_call_res:
+                return ""
+
+            lines = ["Observation:"]
+            for name, result in tool_call_res:
+                lines.append(f"[{name} result]")
+                lines.append(str(result))
+
+            return "\n".join(lines)
+
         def append_user_content(hist, content):
             if hist[-1]["role"] == "user":
                 hist[-1]["content"] += content
@@ -333,7 +371,7 @@ class Agent(LLM, ToolBase):
                 hist.append({"role": "user", "content": content})
 
         st = timer()
-        task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
+        task_desc = build_task_desc(prompt, user_request, user_defined_prompt)
         self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
         for _ in range(self._param.max_rounds + 1):
             if self.check_if_canceled("Agent streaming"):
@@ -364,7 +402,7 @@ class Agent(LLM, ToolBase):
 
             results = await asyncio.gather(*tool_tasks) if tool_tasks else []
             st = timer()
-            reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
+            reflection = build_observation(results)
             append_user_content(hist, reflection)
             self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
 
@@ -393,6 +431,135 @@ Respond immediately with your final comprehensive answer.
         async for txt, tkcnt in complete():
             yield txt, tkcnt
+
+    # async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+    #     token_count = 0
+    #     tool_metas = self.tool_meta
+    #     hist = deepcopy(history)
+    #     last_calling = ""
+    #     if len(hist) > 3:
+    #         st = timer()
+    #         user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
+    #         self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
+    #     else:
+    #         user_request = history[-1]["content"]
+
+    #     async def use_tool_async(name, args):
+    #         nonlocal hist, use_tools, last_calling
+    #         logging.info(f"{last_calling=} == {name=}")
+    #         last_calling = name
+    #         tool_response = await self.toolcall_session.tool_call_async(name, args)
|
||||||
|
# use_tools.append({
|
||||||
|
# "name": name,
|
||||||
|
# "arguments": args,
|
||||||
|
# "results": tool_response
|
||||||
|
# })
|
||||||
|
# # self.callback("add_memory", {}, "...")
|
||||||
|
# #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
|
||||||
|
|
||||||
|
# return name, tool_response
|
||||||
|
|
||||||
|
# async def complete():
|
||||||
|
# nonlocal hist
|
||||||
|
# need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||||
|
# if schema_prompt:
|
||||||
|
# need2cite = False
|
||||||
|
# cited = False
|
||||||
|
# if hist and hist[0]["role"] == "system":
|
||||||
|
# if schema_prompt:
|
||||||
|
# hist[0]["content"] += "\n" + schema_prompt
|
||||||
|
# if need2cite and len(hist) < 7:
|
||||||
|
# hist[0]["content"] += citation_prompt()
|
||||||
|
# cited = True
|
||||||
|
# yield "", token_count
|
||||||
|
|
||||||
|
# _hist = hist
|
||||||
|
# if len(hist) > 12:
|
||||||
|
# _hist = [hist[0], hist[1], *hist[-10:]]
|
||||||
|
# entire_txt = ""
|
||||||
|
# async for delta_ans in self._generate_streamly(_hist):
|
||||||
|
# if not need2cite or cited:
|
||||||
|
# yield delta_ans, 0
|
||||||
|
# entire_txt += delta_ans
|
||||||
|
# if not need2cite or cited:
|
||||||
|
# return
|
||||||
|
|
||||||
|
# st = timer()
|
||||||
|
# txt = ""
|
||||||
|
# async for delta_ans in self._gen_citations_async(entire_txt):
|
||||||
|
# if self.check_if_canceled("Agent streaming"):
|
||||||
|
# return
|
||||||
|
# yield delta_ans, 0
|
||||||
|
# txt += delta_ans
|
||||||
|
|
||||||
|
# self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
|
||||||
|
|
||||||
|
# def append_user_content(hist, content):
|
||||||
|
# if hist[-1]["role"] == "user":
|
||||||
|
# hist[-1]["content"] += content
|
||||||
|
# else:
|
||||||
|
# hist.append({"role": "user", "content": content})
|
||||||
|
|
||||||
|
# st = timer()
|
||||||
|
# task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||||
|
# self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||||
|
# for _ in range(self._param.max_rounds + 1):
|
||||||
|
# if self.check_if_canceled("Agent streaming"):
|
||||||
|
# return
|
||||||
|
# response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||||
|
# # self.callback("next_step", {}, str(response)[:256]+"...")
|
||||||
|
# token_count += tk or 0
|
||||||
|
# hist.append({"role": "assistant", "content": response})
|
||||||
|
# try:
|
||||||
|
# functions = json_repair.loads(re.sub(r"```.*", "", response))
|
||||||
|
# if not isinstance(functions, list):
|
||||||
|
# raise TypeError(f"List should be returned, but `{functions}`")
|
||||||
|
# for f in functions:
|
||||||
|
# if not isinstance(f, dict):
|
||||||
|
# raise TypeError(f"An object type should be returned, but `{f}`")
|
||||||
|
|
||||||
|
# tool_tasks = []
|
||||||
|
# for func in functions:
|
||||||
|
# name = func["name"]
|
||||||
|
# args = func["arguments"]
|
||||||
|
# if name == COMPLETE_TASK:
|
||||||
|
# append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||||
|
# async for txt, tkcnt in complete():
|
||||||
|
# yield txt, tkcnt
|
||||||
|
# return
|
||||||
|
|
||||||
|
# tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
|
||||||
|
|
||||||
|
# results = await asyncio.gather(*tool_tasks) if tool_tasks else []
|
||||||
|
# st = timer()
|
||||||
|
# reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
|
||||||
|
# append_user_content(hist, reflection)
|
||||||
|
# self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||||
|
|
||||||
|
# except Exception as e:
|
||||||
|
# logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
||||||
|
# e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
|
||||||
|
# append_user_content(hist, str(e))
|
||||||
|
|
||||||
|
# logging.warning( f"Exceed max rounds: {self._param.max_rounds}")
|
||||||
|
# final_instruction = f"""
|
||||||
|
# {user_request}
|
||||||
|
# IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
|
||||||
|
# Instructions:
|
||||||
|
# 1. SYNTHESIZE all information collected during this conversation
|
||||||
|
# 2. Provide a COMPLETE response using existing data - do not suggest additional research
|
||||||
|
# 3. Structure your response as a FINAL DELIVERABLE, not a plan
|
||||||
|
# 4. If information is incomplete, state what you found and provide the best analysis possible with available data
|
||||||
|
# 5. DO NOT mention conversation limits or suggest further steps
|
||||||
|
# 6. Focus on delivering VALUE with the information already gathered
|
||||||
|
# Respond immediately with your final comprehensive answer.
|
||||||
|
# """
|
||||||
|
# if self.check_if_canceled("Agent final instruction"):
|
||||||
|
# return
|
||||||
|
# append_user_content(hist, final_instruction)
|
||||||
|
|
||||||
|
# async for txt, tkcnt in complete():
|
||||||
|
# yield txt, tkcnt
|
||||||
|
|
||||||
     async def _gen_citations_async(self, text):
         retrievals = self._canvas.get_reference()
         retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}

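Taken together, the hunks above replace the two per-round LLM calls (analyze_task_async and reflect_async) with purely local string building, so a round of the simplified ReAct loop only spends model tokens on planning and on the final answer. Below is a minimal sketch of that round shape; plan_step and call_tool are hypothetical stand-ins for the chat model and the tool session and are not part of the diff.

# Hypothetical sketch of one simplified ReAct round; plan_step/call_tool are stand-ins, not part of the diff.
import asyncio
import json

def build_task_desc(prompt, user_request, user_defined_prompt=None):
    # Plain concatenation, mirroring the helper added in the diff: no LLM call.
    desc = f"### Agent Prompt\n{prompt}\n\n### User Request\n{user_request}\n\n"
    if user_defined_prompt:
        desc += "\n### User Defined Prompts\n" + json.dumps(user_defined_prompt, ensure_ascii=False, indent=2) + "\n"
    return desc

def build_observation(tool_call_res):
    # Tool outputs are echoed back verbatim as an "Observation:" block.
    if not tool_call_res:
        return ""
    lines = ["Observation:"]
    for name, result in tool_call_res:
        lines += [f"[{name} result]", str(result)]
    return "\n".join(lines)

async def react_round(plan_step, call_tool, hist, prompt, user_request):
    task_desc = build_task_desc(prompt, user_request)      # local, no LLM call
    calls = await plan_step(hist, task_desc)                # one LLM call, returns [{"name": ..., "arguments": ...}]
    results = await asyncio.gather(*(call_tool(c["name"], c["arguments"]) for c in calls))
    hist.append({"role": "user", "content": build_observation([(c["name"], r) for c, r in zip(calls, results)])})
    return hist
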
@@ -27,6 +27,10 @@ import pandas as pd
 from agent import settings
 from common.connection_utils import timeout

+
+
+from common.misc_utils import thread_pool_exec
+
 _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
 _DEPRECATED_PARAMS = "_deprecated_params"
 _USER_FEEDED_PARAMS = "_user_feeded_params"
@@ -361,7 +365,7 @@ class ComponentParamBase(ABC):
 class ComponentBase(ABC):
     component_name: str
     thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
-    variable_ref_patt = r"\{* *\{([a-zA-Z_:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
+    variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"

     def __str__(self):
         """
@@ -379,6 +383,7 @@ class ComponentBase(ABC):

     def __init__(self, canvas, id, param: ComponentParamBase):
         from agent.canvas import Graph  # Local import to avoid cyclic dependency
+
         assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
         self._canvas = canvas
         self._id = id
@@ -430,7 +435,7 @@ class ComponentBase(ABC):
             elif asyncio.iscoroutinefunction(self._invoke):
                 await self._invoke(**kwargs)
             else:
-                await asyncio.to_thread(self._invoke, **kwargs)
+                await thread_pool_exec(self._invoke, **kwargs)
         except Exception as e:
             if self.get_exception_default_value():
                 self.set_exception_default_value()
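Both this file and the tool base further below swap asyncio.to_thread for thread_pool_exec from common.misc_utils. The helper's implementation is not shown in this diff; the following is a hedged sketch of what such a helper might look like, assuming it only routes blocking callables onto one shared, bounded executor.

# Hypothetical sketch of a thread_pool_exec helper; the real one lives in common/misc_utils.py and is not shown here.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

_EXECUTOR = ThreadPoolExecutor(max_workers=32)  # assumed size, not taken from the diff

async def thread_pool_exec(fn, *args, **kwargs):
    # Run a blocking callable on a shared, bounded pool instead of
    # asyncio.to_thread's per-loop default executor.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_EXECUTOR, partial(fn, *args, **kwargs))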
@@ -28,7 +28,7 @@ class BeginParam(UserFillUpParam):
         self.prologue = "Hi! I'm your smart assistant. What can I do for you?"

     def check(self):
-        self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
+        self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"])

     def get_input_form(self) -> dict[str, dict]:
         return getattr(self, "inputs")
@@ -97,6 +97,13 @@ Here's description of each category:
 class Categorize(LLM, ABC):
     component_name = "Categorize"

+    def get_input_elements(self) -> dict[str, dict]:
+        query_key = self._param.query or "sys.query"
+        elements = self.get_input_elements_from_text(f"{{{query_key}}}")
+        if not elements:
+            logging.warning(f"[Categorize] input element not detected for query key: {query_key}")
+        return elements
+
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("Categorize processing"):
@@ -105,12 +112,15 @@ class Categorize(LLM, ABC):
         msg = self._canvas.get_history(self._param.message_history_window_size)
         if not msg:
             msg = [{"role": "user", "content": ""}]
-        if kwargs.get("sys.query"):
-            msg[-1]["content"] = kwargs["sys.query"]
-            self.set_input_value("sys.query", kwargs["sys.query"])
-        else:
-            msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
-            self.set_input_value(self._param.query, msg[-1]["content"])
+        query_key = self._param.query or "sys.query"
+        if query_key in kwargs:
+            query_value = kwargs[query_key]
+        else:
+            query_value = self._canvas.get_variable_value(query_key)
+        if query_value is None:
+            query_value = ""
+        msg[-1]["content"] = query_value
+        self.set_input_value(query_key, msg[-1]["content"])
         self._param.update_prompt()
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)

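The practical effect is that Categorize no longer hard-codes sys.query: whichever variable reference is configured in `query` is looked up first in the invocation kwargs and then in the canvas. A small sketch of that resolution order under assumptions (dict-backed canvas stand-in, hypothetical names):

# Hypothetical sketch of the query-key resolution order; not part of the diff.
def resolve_query(param_query, kwargs, canvas_vars):
    query_key = param_query or "sys.query"
    if query_key in kwargs:
        value = kwargs[query_key]           # explicit input wins
    else:
        value = canvas_vars.get(query_key)  # fall back to the canvas variable
    return query_key, "" if value is None else value

# e.g. resolve_query("begin@question", {}, {"begin@question": "hi"}) -> ("begin@question", "hi")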
@@ -56,7 +56,6 @@ class LLMParam(ComponentParamBase):
         self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
         self.check_decimal_float(float(self.top_p), "[Agent] Top P")
         self.check_empty(self.llm_id, "[Agent] LLM")
-        self.check_empty(self.sys_prompt, "[Agent] System prompt")
         self.check_empty(self.prompts, "[Agent] User prompt")

     def gen_conf(self):
@@ -113,6 +113,10 @@ class LoopItem(ComponentBase, ABC):
             return len(var) == 0
         elif operator == "not empty":
             return len(var) > 0
+        elif var is None:
+            if operator == "empty":
+                return True
+            return False

         raise Exception(f"Invalid operator: {operator}")

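The added branch gives None values an explicit answer for the emptiness operators instead of letting len(None) raise. A tiny illustration of the intended truth table, written as a hypothetical condensed helper rather than the component's own method:

# Hypothetical condensed version of the operator check; not part of the diff.
def is_empty_check(var, operator):
    if var is None:
        return operator == "empty"
    if operator == "empty":
        return len(var) == 0
    if operator == "not empty":
        return len(var) > 0
    raise Exception(f"Invalid operator: {operator}")

# is_empty_check(None, "empty") -> True; is_empty_check([], "not empty") -> False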
@@ -33,6 +33,8 @@ from common.connection_utils import timeout
 from common.misc_utils import get_uuid
 from common import settings

+from api.db.joint_services.memory_message_service import queue_save_to_memory_task
+

 class MessageParam(ComponentParamBase):
     """
@@ -166,6 +168,7 @@ class Message(ComponentBase):

         self.set_output("content", all_content)
         self._convert_content(all_content)
+        await self._save_to_memory(all_content)

     def _is_jinjia2(self, content:str) -> bool:
         patt = [
@@ -198,6 +201,7 @@ class Message(ComponentBase):

         self.set_output("content", content)
         self._convert_content(content)
+        self._save_to_memory(content)

     def thoughts(self) -> str:
         return ""
@@ -421,3 +425,16 @@ class Message(ComponentBase):

         except Exception as e:
             logging.error(f"Error converting content to {self._param.output_format}: {e}")
+
+    async def _save_to_memory(self, content):
+        if not hasattr(self._param, "memory_ids") or not self._param.memory_ids:
+            return True, "No memory selected."
+
+        message_dict = {
+            "user_id": self._canvas._tenant_id,
+            "agent_id": self._canvas._id,
+            "session_id": self._canvas.task_id,
+            "user_input": self._canvas.get_sys_query(),
+            "agent_response": content
+        }
+        return await queue_save_to_memory_task(self._param.memory_ids, message_dict)
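Because _save_to_memory is declared async, the first call site above is awaited while the second (in the non-async path) is not, so that one would need to be awaited or scheduled to actually run. A hedged usage sketch, with queue_save_to_memory_task(memory_ids, message_dict) taken exactly as called in the diff and everything else hypothetical:

# Hypothetical usage sketch; queue_save_to_memory_task is assumed to accept
# (memory_ids, message_dict) exactly as in the diff above.
import asyncio

async def save_turn(queue_save_to_memory_task, memory_ids, tenant_id, agent_id, session_id, user_input, answer):
    if not memory_ids:
        return True, "No memory selected."
    message_dict = {
        "user_id": tenant_id,
        "agent_id": agent_id,
        "session_id": session_id,
        "user_input": user_input,
        "agent_response": answer,
    }
    return await queue_save_to_memory_task(memory_ids, message_dict)

# From a synchronous caller, schedule it instead of calling it bare:
# asyncio.get_running_loop().create_task(save_turn(...))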
@@ -1,38 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from agent.component.base import ComponentParamBase, ComponentBase
-
-
-class WebhookParam(ComponentParamBase):
-
-    """
-    Define the Begin component parameters.
-    """
-    def __init__(self):
-        super().__init__()
-
-    def get_input_form(self) -> dict[str, dict]:
-        return getattr(self, "inputs")
-
-
-class Webhook(ComponentBase):
-    component_name = "Webhook"
-
-    def _invoke(self, **kwargs):
-        pass
-
-    def thoughts(self) -> str:
-        return ""
@@ -27,6 +27,10 @@ from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
 from timeit import default_timer as timer


+
+
+from common.misc_utils import thread_pool_exec
+
 class ToolParameter(TypedDict):
     type: str
     description: str
@@ -56,12 +60,12 @@ class LLMToolPluginCallSession(ToolCallSession):
         st = timer()
         tool_obj = self.tools_map[name]
         if isinstance(tool_obj, MCPToolCallSession):
-            resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
+            resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
         else:
             if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
                 resp = await tool_obj.invoke_async(**arguments)
             else:
-                resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
+                resp = await thread_pool_exec(tool_obj.invoke, **arguments)

         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
@@ -122,6 +126,7 @@ class ToolParamBase(ComponentParamBase):
 class ToolBase(ComponentBase):
     def __init__(self, canvas, id, param: ComponentParamBase):
         from agent.canvas import Canvas  # Local import to avoid cyclic dependency
+
         assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
         self._canvas = canvas
         self._id = id
@@ -164,7 +169,7 @@ class ToolBase(ComponentBase):
             elif asyncio.iscoroutinefunction(self._invoke):
                 res = await self._invoke(**kwargs)
             else:
-                res = await asyncio.to_thread(self._invoke, **kwargs)
+                res = await thread_pool_exec(self._invoke, **kwargs)
         except Exception as e:
             self._param.outputs["_ERROR"] = {"value": str(e)}
             logging.exception(e)
@@ -86,6 +86,12 @@ class ExeSQL(ToolBase, ABC):

         def convert_decimals(obj):
             from decimal import Decimal
+            import math
+            if isinstance(obj, float):
+                # Handle NaN and Infinity which are not valid JSON values
+                if math.isnan(obj) or math.isinf(obj):
+                    return None
+                return obj
             if isinstance(obj, Decimal):
                 return float(obj)  # or str(obj)
             elif isinstance(obj, dict):
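The extra branch keeps SQL results JSON-serializable: floats that are NaN or infinite become null, Decimals become plain floats. A standalone sketch of the same sanitizer applied recursively (a hypothetical helper, not the component itself):

# Hypothetical standalone version of the sanitizer; not part of the diff.
import json
import math
from decimal import Decimal

def sanitize(obj):
    if isinstance(obj, float):
        return None if (math.isnan(obj) or math.isinf(obj)) else obj
    if isinstance(obj, Decimal):
        return float(obj)
    if isinstance(obj, dict):
        return {k: sanitize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [sanitize(v) for v in obj]
    return obj

# json.dumps(sanitize({"price": Decimal("9.90"), "ratio": float("nan")})) -> '{"price": 9.9, "ratio": null}'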
@@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
 from common.metadata_utils import apply_meta_data_filter
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.memory_service import MemoryService
+from api.db.joint_services import memory_message_service
 from common import settings
 from common.connection_utils import timeout
 from rag.app.tag import label_question
-from rag.prompts.generator import cross_languages, kb_prompt
+from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt


 class RetrievalParam(ToolParamBase):
@@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
         self.top_n = 8
         self.top_k = 1024
         self.kb_ids = []
+        self.memory_ids = []
         self.kb_vars = []
         self.rerank_id = ""
         self.empty_response = ""
@@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase):
 class Retrieval(ToolBase, ABC):
     component_name = "Retrieval"

-    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
-    async def _invoke_async(self, **kwargs):
-        if self.check_if_canceled("Retrieval processing"):
-            return
-
-        if not kwargs.get("query"):
-            self.set_output("formalized_content", self._param.empty_response)
-            return
-
+    async def _retrieve_kb(self, query_text: str):
         kb_ids: list[str] = []
         for id in self._param.kb_ids:
             if id.find("@") < 0:
@@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
         if self._param.rerank_id:
             rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)

-        vars = self.get_input_elements_from_text(kwargs["query"])
-        vars = {k:o["value"] for k,o in vars.items()}
-        query = self.string_format(kwargs["query"], vars)
+        vars = self.get_input_elements_from_text(query_text)
+        vars = {k: o["value"] for k, o in vars.items()}
+        query = self.string_format(query_text, vars)

-        doc_ids=[]
-        if self._param.meta_data_filter!={}:
+        doc_ids = []
+        if self._param.meta_data_filter != {}:
             metas = DocumentService.get_meta_by_kbs(kb_ids)

             def _resolve_manual_filter(flt: dict) -> dict:
@@ -179,7 +174,7 @@ class Retrieval(ToolBase, ABC):

         if kbs:
             query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retriever.retrieval(
+            kbinfos = await settings.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
@@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):

             if self._param.toc_enhance:
                 chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
-                cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+                cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
+                                                                chat_mdl, self._param.top_n)
                 if self.check_if_canceled("Retrieval processing"):
                     return
                 if cks:
                     kbinfos["chunks"] = cks
-            kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
+            kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
+                                                                         [kb.tenant_id for kb in kbs])
             if self._param.use_kg:
-                ck = settings.kg_retriever.retrieval(query,
+                ck = await settings.kg_retriever.retrieval(query,
                                                      [kb.tenant_id for kb in kbs],
                                                      kb_ids,
                                                      embd_mdl,
                                                      LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
                 if self.check_if_canceled("Retrieval processing"):
                     return
                 if ck["content_with_weight"]:
@@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
             kbinfos = {"chunks": [], "doc_aggs": []}

         if self._param.use_kg and kbs:
-            ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
+            ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
+                                                       LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
             if self.check_if_canceled("Retrieval processing"):
                 return
             if ck["content_with_weight"]:
@@ -248,6 +246,54 @@ class Retrieval(ToolBase, ABC):

         return form_cnt

+    async def _retrieve_memory(self, query_text: str):
+        memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
+        memory_list = MemoryService.get_by_ids(memory_ids)
+        if not memory_list:
+            raise Exception("No memory is selected.")
+
+        embd_names = list({memory.embd_id for memory in memory_list})
+        assert len(embd_names) == 1, "Memory use different embedding models."
+
+        vars = self.get_input_elements_from_text(query_text)
+        vars = {k: o["value"] for k, o in vars.items()}
+        query = self.string_format(query_text, vars)
+        # query message
+        message_list = memory_message_service.query_message({"memory_id": memory_ids}, {
+            "query": query,
+            "similarity_threshold": self._param.similarity_threshold,
+            "keywords_similarity_weight": self._param.keywords_similarity_weight,
+            "top_n": self._param.top_n
+        })
+        if not message_list:
+            self.set_output("formalized_content", self._param.empty_response)
+            return ""
+        formated_content = "\n".join(memory_prompt(message_list, 200000))
+        # set formalized_content output
+        self.set_output("formalized_content", formated_content)
+
+        return formated_content
+
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
+    async def _invoke_async(self, **kwargs):
+        if self.check_if_canceled("Retrieval processing"):
+            return
+        if not kwargs.get("query"):
+            self.set_output("formalized_content", self._param.empty_response)
+            return
+
+        if hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "dataset":
+            return await self._retrieve_kb(kwargs["query"])
+        elif hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "memory":
+            return await self._retrieve_memory(kwargs["query"])
+        elif self._param.kb_ids:
+            return await self._retrieve_kb(kwargs["query"])
+        elif hasattr(self._param, "memory_ids") and self._param.memory_ids:
+            return await self._retrieve_memory(kwargs["query"])
+        else:
+            self.set_output("formalized_content", self._param.empty_response)
+            return

     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         return asyncio.run(self._invoke_async(**kwargs))
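The _invoke_async rewrite turns Retrieval into a small dispatcher: an explicit retrieval_from value wins, otherwise configured kb_ids select the dataset path and memory_ids the memory path. A compact sketch of that precedence, with retrieve_kb and retrieve_memory as hypothetical stand-ins for the two methods above:

# Hypothetical condensed view of the dispatch order; not part of the diff.
async def dispatch(param, query, retrieve_kb, retrieve_memory, empty):
    source = getattr(param, "retrieval_from", None)
    if source == "dataset":
        return await retrieve_kb(query)
    if source == "memory":
        return await retrieve_memory(query)
    if param.kb_ids:                        # legacy configs without retrieval_from
        return await retrieve_kb(query)
    if getattr(param, "memory_ids", None):
        return await retrieve_memory(query)
    return empty()                          # nothing configured -> empty response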

@@ -1 +0,0 @@
-from .deep_research import DeepResearcher as DeepResearcher
||||||
@ -1,238 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from functools import partial
|
|
||||||
from agentic_reasoning.prompts import BEGIN_SEARCH_QUERY, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT, MAX_SEARCH_LIMIT, \
|
|
||||||
END_SEARCH_QUERY, REASON_PROMPT, RELEVANT_EXTRACTION_PROMPT
|
|
||||||
from api.db.services.llm_service import LLMBundle
|
|
||||||
from rag.nlp import extract_between
|
|
||||||
from rag.prompts import kb_prompt
|
|
||||||
from rag.utils.tavily_conn import Tavily
|
|
||||||
|
|
||||||
|
|
||||||
class DeepResearcher:
|
|
||||||
def __init__(self,
|
|
||||||
chat_mdl: LLMBundle,
|
|
||||||
prompt_config: dict,
|
|
||||||
kb_retrieve: partial = None,
|
|
||||||
kg_retrieve: partial = None
|
|
||||||
):
|
|
||||||
self.chat_mdl = chat_mdl
|
|
||||||
self.prompt_config = prompt_config
|
|
||||||
self._kb_retrieve = kb_retrieve
|
|
||||||
self._kg_retrieve = kg_retrieve
|
|
||||||
|
|
||||||
def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
|
|
||||||
"""General Tag Removal Method"""
|
|
||||||
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
|
|
||||||
return re.sub(pattern, "", text)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _remove_query_tags(text: str) -> str:
|
|
||||||
"""Remove Query Tags"""
|
|
||||||
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _remove_result_tags(text: str) -> str:
|
|
||||||
"""Remove Result Tags"""
|
|
||||||
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
|
|
||||||
|
|
||||||
async def _generate_reasoning(self, msg_history):
|
|
||||||
"""Generate reasoning steps"""
|
|
||||||
query_think = ""
|
|
||||||
if msg_history[-1]["role"] != "user":
|
|
||||||
msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
|
|
||||||
else:
|
|
||||||
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
|
|
||||||
|
|
||||||
async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
|
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
|
||||||
if not ans:
|
|
||||||
continue
|
|
||||||
query_think = ans
|
|
||||||
yield query_think
|
|
||||||
query_think = ""
|
|
||||||
yield query_think
|
|
||||||
|
|
||||||
def _extract_search_queries(self, query_think, question, step_index):
|
|
||||||
"""Extract search queries from thinking"""
|
|
||||||
queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
|
|
||||||
if not queries and step_index == 0:
|
|
||||||
# If this is the first step and no queries are found, use the original question as the query
|
|
||||||
queries = [question]
|
|
||||||
return queries
|
|
||||||
|
|
||||||
def _truncate_previous_reasoning(self, all_reasoning_steps):
|
|
||||||
"""Truncate previous reasoning steps to maintain a reasonable length"""
|
|
||||||
truncated_prev_reasoning = ""
|
|
||||||
for i, step in enumerate(all_reasoning_steps):
|
|
||||||
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
|
|
||||||
|
|
||||||
prev_steps = truncated_prev_reasoning.split('\n\n')
|
|
||||||
if len(prev_steps) <= 5:
|
|
||||||
truncated_prev_reasoning = '\n\n'.join(prev_steps)
|
|
||||||
else:
|
|
||||||
truncated_prev_reasoning = ''
|
|
||||||
for i, step in enumerate(prev_steps):
|
|
||||||
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
|
|
||||||
truncated_prev_reasoning += step + '\n\n'
|
|
||||||
else:
|
|
||||||
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
|
|
||||||
truncated_prev_reasoning += '...\n\n'
|
|
||||||
|
|
||||||
return truncated_prev_reasoning.strip('\n')
|
|
||||||
|
|
||||||
def _retrieve_information(self, search_query):
|
|
||||||
"""Retrieve information from different sources"""
|
|
||||||
# 1. Knowledge base retrieval
|
|
||||||
kbinfos = []
|
|
||||||
try:
|
|
||||||
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Knowledge base retrieval error: {e}")
|
|
||||||
|
|
||||||
# 2. Web retrieval (if Tavily API is configured)
|
|
||||||
try:
|
|
||||||
if self.prompt_config.get("tavily_api_key"):
|
|
||||||
tav = Tavily(self.prompt_config["tavily_api_key"])
|
|
||||||
tav_res = tav.retrieve_chunks(search_query)
|
|
||||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
|
||||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Web retrieval error: {e}")
|
|
||||||
|
|
||||||
# 3. Knowledge graph retrieval (if configured)
|
|
||||||
try:
|
|
||||||
if self.prompt_config.get("use_kg") and self._kg_retrieve:
|
|
||||||
ck = self._kg_retrieve(question=search_query)
|
|
||||||
if ck["content_with_weight"]:
|
|
||||||
kbinfos["chunks"].insert(0, ck)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Knowledge graph retrieval error: {e}")
|
|
||||||
|
|
||||||
return kbinfos
|
|
||||||
|
|
||||||
def _update_chunk_info(self, chunk_info, kbinfos):
|
|
||||||
"""Update chunk information for citations"""
|
|
||||||
if not chunk_info["chunks"]:
|
|
||||||
# If this is the first retrieval, use the retrieval results directly
|
|
||||||
for k in chunk_info.keys():
|
|
||||||
chunk_info[k] = kbinfos[k]
|
|
||||||
else:
|
|
||||||
# Merge newly retrieved information, avoiding duplicates
|
|
||||||
cids = [c["chunk_id"] for c in chunk_info["chunks"]]
|
|
||||||
for c in kbinfos["chunks"]:
|
|
||||||
if c["chunk_id"] not in cids:
|
|
||||||
chunk_info["chunks"].append(c)
|
|
||||||
|
|
||||||
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
|
|
||||||
for d in kbinfos["doc_aggs"]:
|
|
||||||
if d["doc_id"] not in dids:
|
|
||||||
chunk_info["doc_aggs"].append(d)
|
|
||||||
|
|
||||||
async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
|
|
||||||
"""Extract and summarize relevant information"""
|
|
||||||
summary_think = ""
|
|
||||||
async for ans in self.chat_mdl.async_chat_streamly(
|
|
||||||
RELEVANT_EXTRACTION_PROMPT.format(
|
|
||||||
prev_reasoning=truncated_prev_reasoning,
|
|
||||||
search_query=search_query,
|
|
||||||
document="\n".join(kb_prompt(kbinfos, 4096))
|
|
||||||
),
|
|
||||||
[{"role": "user",
|
|
||||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
|
||||||
{"temperature": 0.7}):
|
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
|
||||||
if not ans:
|
|
||||||
continue
|
|
||||||
summary_think = ans
|
|
||||||
yield summary_think
|
|
||||||
summary_think = ""
|
|
||||||
|
|
||||||
yield summary_think
|
|
||||||
|
|
||||||
async def thinking(self, chunk_info: dict, question: str):
|
|
||||||
executed_search_queries = []
|
|
||||||
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
|
|
||||||
all_reasoning_steps = []
|
|
||||||
think = "<think>"
|
|
||||||
|
|
||||||
for step_index in range(MAX_SEARCH_LIMIT + 1):
|
|
||||||
# Check if the maximum search limit has been reached
|
|
||||||
if step_index == MAX_SEARCH_LIMIT - 1:
|
|
||||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
|
|
||||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
|
||||||
all_reasoning_steps.append(summary_think)
|
|
||||||
msg_history.append({"role": "assistant", "content": summary_think})
|
|
||||||
break
|
|
||||||
|
|
||||||
# Step 1: Generate reasoning
|
|
||||||
query_think = ""
|
|
||||||
async for ans in self._generate_reasoning(msg_history):
|
|
||||||
query_think = ans
|
|
||||||
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
|
|
||||||
|
|
||||||
think += self._remove_query_tags(query_think)
|
|
||||||
all_reasoning_steps.append(query_think)
|
|
||||||
|
|
||||||
# Step 2: Extract search queries
|
|
||||||
queries = self._extract_search_queries(query_think, question, step_index)
|
|
||||||
if not queries and step_index > 0:
|
|
||||||
# If not the first step and no queries, end the search process
|
|
||||||
break
|
|
||||||
|
|
||||||
# Process each search query
|
|
||||||
for search_query in queries:
|
|
||||||
logging.info(f"[THINK]Query: {step_index}. {search_query}")
|
|
||||||
msg_history.append({"role": "assistant", "content": search_query})
|
|
||||||
think += f"\n\n> {step_index + 1}. {search_query}\n\n"
|
|
||||||
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
|
|
||||||
|
|
||||||
# Check if the query has already been executed
|
|
||||||
if search_query in executed_search_queries:
|
|
||||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
|
|
||||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
|
||||||
all_reasoning_steps.append(summary_think)
|
|
||||||
msg_history.append({"role": "user", "content": summary_think})
|
|
||||||
think += summary_think
|
|
||||||
continue
|
|
||||||
|
|
||||||
executed_search_queries.append(search_query)
|
|
||||||
|
|
||||||
# Step 3: Truncate previous reasoning steps
|
|
||||||
truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
|
|
||||||
|
|
||||||
# Step 4: Retrieve information
|
|
||||||
kbinfos = self._retrieve_information(search_query)
|
|
||||||
|
|
||||||
# Step 5: Update chunk information
|
|
||||||
self._update_chunk_info(chunk_info, kbinfos)
|
|
||||||
|
|
||||||
# Step 6: Extract relevant information
|
|
||||||
think += "\n\n"
|
|
||||||
summary_think = ""
|
|
||||||
async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
|
|
||||||
summary_think = ans
|
|
||||||
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
|
|
||||||
|
|
||||||
all_reasoning_steps.append(summary_think)
|
|
||||||
msg_history.append(
|
|
||||||
{"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
|
|
||||||
think += self._remove_result_tags(summary_think)
|
|
||||||
logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
|
|
||||||
|
|
||||||
yield think + "</think>"
|
|
||||||
@ -1,147 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
|
|
||||||
END_SEARCH_QUERY = "<|end_search_query|>"
|
|
||||||
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
|
|
||||||
END_SEARCH_RESULT = "<|end_search_result|>"
|
|
||||||
MAX_SEARCH_LIMIT = 6
|
|
||||||
|
|
||||||
REASON_PROMPT = f"""You are an advanced reasoning agent. Your goal is to answer the user's question by breaking it down into a series of verifiable steps.
|
|
||||||
|
|
||||||
You have access to a powerful search tool to find information.
|
|
||||||
|
|
||||||
**Your Task:**
|
|
||||||
1. Analyze the user's question.
|
|
||||||
2. If you need information, issue a search query to find a specific fact.
|
|
||||||
3. Review the search results.
|
|
||||||
4. Repeat the search process until you have all the facts needed to answer the question.
|
|
||||||
5. Once you have gathered sufficient information, synthesize the facts and provide the final answer directly.
|
|
||||||
|
|
||||||
**Tool Usage:**
|
|
||||||
- To search, you MUST write your query between the special tokens: {BEGIN_SEARCH_QUERY}your query{END_SEARCH_QUERY}.
|
|
||||||
- The system will provide results between {BEGIN_SEARCH_RESULT}search results{END_SEARCH_RESULT}.
|
|
||||||
- You have a maximum of {MAX_SEARCH_LIMIT} search attempts.
|
|
||||||
|
|
||||||
---
|
|
||||||
**Example 1: Multi-hop Question**
|
|
||||||
|
|
||||||
**Question:** "Are both the directors of Jaws and Casino Royale from the same country?"
|
|
||||||
|
|
||||||
**Your Thought Process & Actions:**
|
|
||||||
First, I need to identify the director of Jaws.
|
|
||||||
{BEGIN_SEARCH_QUERY}who is the director of Jaws?{END_SEARCH_QUERY}
|
|
||||||
[System returns search results]
|
|
||||||
{BEGIN_SEARCH_RESULT}
|
|
||||||
Jaws is a 1975 American thriller film directed by Steven Spielberg.
|
|
||||||
{END_SEARCH_RESULT}
|
|
||||||
Okay, the director of Jaws is Steven Spielberg. Now I need to find out his nationality.
|
|
||||||
{BEGIN_SEARCH_QUERY}where is Steven Spielberg from?{END_SEARCH_QUERY}
|
|
||||||
[System returns search results]
|
|
||||||
{BEGIN_SEARCH_RESULT}
|
|
||||||
Steven Allan Spielberg is an American filmmaker. Born in Cincinnati, Ohio...
|
|
||||||
{END_SEARCH_RESULT}
|
|
||||||
So, Steven Spielberg is from the USA. Next, I need to find the director of Casino Royale.
|
|
||||||
{BEGIN_SEARCH_QUERY}who is the director of Casino Royale 2006?{END_SEARCH_QUERY}
|
|
||||||
[System returns search results]
|
|
||||||
{BEGIN_SEARCH_RESULT}
|
|
||||||
Casino Royale is a 2006 spy film directed by Martin Campbell.
|
|
||||||
{END_SEARCH_RESULT}
|
|
||||||
The director of Casino Royale is Martin Campbell. Now I need his nationality.
|
|
||||||
{BEGIN_SEARCH_QUERY}where is Martin Campbell from?{END_SEARCH_QUERY}
|
|
||||||
[System returns search results]
|
|
||||||
{BEGIN_SEARCH_RESULT}
|
|
||||||
Martin Campbell (born 24 October 1943) is a New Zealand film and television director.
|
|
||||||
{END_SEARCH_RESULT}
|
|
||||||
I have all the information. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand. They are not from the same country.
|
|
||||||
|
|
||||||
Final Answer: No, the directors of Jaws and Casino Royale are not from the same country. Steven Spielberg is from the USA, and Martin Campbell is from New Zealand.
|
|
||||||
|
|
||||||
---
|
|
||||||
**Example 2: Simple Fact Retrieval**
|
|
||||||
|
|
||||||
**Question:** "When was the founder of craigslist born?"
|
|
||||||
|
|
||||||
**Your Thought Process & Actions:**
|
|
||||||
First, I need to know who founded craigslist.
|
|
||||||
{BEGIN_SEARCH_QUERY}who founded craigslist?{END_SEARCH_QUERY}
|
|
||||||
[System returns search results]
|
|
||||||
{BEGIN_SEARCH_RESULT}
|
|
||||||
Craigslist was founded in 1995 by Craig Newmark.
|
|
||||||
{END_SEARCH_RESULT}
|
|
||||||
The founder is Craig Newmark. Now I need his birth date.
|
|
||||||
{BEGIN_SEARCH_QUERY}when was Craig Newmark born?{END_SEARCH_QUERY}
|
|
||||||
[System returns search results]
|
|
||||||
{BEGIN_SEARCH_RESULT}
|
|
||||||
Craig Newmark was born on December 6, 1952.
|
|
||||||
{END_SEARCH_RESULT}
|
|
||||||
I have found the answer.
|
|
||||||
|
|
||||||
Final Answer: The founder of craigslist, Craig Newmark, was born on December 6, 1952.
|
|
||||||
|
|
||||||
---
|
|
||||||
**Important Rules:**
|
|
||||||
- **One Fact at a Time:** Decompose the problem and issue one search query at a time to find a single, specific piece of information.
|
|
||||||
- **Be Precise:** Formulate clear and precise search queries. If a search fails, rephrase it.
|
|
||||||
- **Synthesize at the End:** Do not provide the final answer until you have completed all necessary searches.
|
|
||||||
- **Language Consistency:** Your search queries should be in the same language as the user's question.
|
|
||||||
|
|
||||||
Now, begin your work. Please answer the following question by thinking step-by-step.
|
|
||||||
"""
|
|
||||||
|
|
||||||
RELEVANT_EXTRACTION_PROMPT = """You are a highly efficient information extraction module. Your sole purpose is to extract the single most relevant piece of information from the provided `Searched Web Pages` that directly answers the `Current Search Query`.
|
|
||||||
|
|
||||||
**Your Task:**
|
|
||||||
1. Read the `Current Search Query` to understand what specific information is needed.
|
|
||||||
2. Scan the `Searched Web Pages` to find the answer to that query.
|
|
||||||
3. Extract only the essential, factual information that answers the query. Be concise.
|
|
||||||
|
|
||||||
**Context (For Your Information Only):**
|
|
||||||
The `Previous Reasoning Steps` are provided to give you context on the overall goal, but your primary focus MUST be on answering the `Current Search Query`. Do not use information from the previous steps in your output.
|
|
||||||
|
|
||||||
**Output Format:**
|
|
||||||
Your response must follow one of two formats precisely.
|
|
||||||
|
|
||||||
1. **If a direct and relevant answer is found:**
|
|
||||||
- Start your response immediately with `Final Information`.
|
|
||||||
- Provide only the extracted fact(s). Do not add any extra conversational text.
|
|
||||||
|
|
||||||
*Example:*
|
|
||||||
`Current Search Query`: Where is Martin Campbell from?
|
|
||||||
`Searched Web Pages`: [Long article snippet about Martin Campbell's career, which includes the sentence "Martin Campbell (born 24 October 1943) is a New Zealand film and television director..."]
|
|
||||||
|
|
||||||
*Your Output:*
|
|
||||||
Final Information
|
|
||||||
Martin Campbell is a New Zealand film and television director.
|
|
||||||
|
|
||||||
2. **If no relevant answer that directly addresses the query is found in the web pages:**
|
|
||||||
- Start your response immediately with `Final Information`.
|
|
||||||
- Write the exact phrase: `No helpful information found.`
|
|
||||||
|
|
||||||
---
|
|
||||||
**BEGIN TASK**
|
|
||||||
|
|
||||||
**Inputs:**
|
|
||||||
|
|
||||||
- **Previous Reasoning Steps:**
|
|
||||||
{prev_reasoning}
|
|
||||||
|
|
||||||
- **Current Search Query:**
|
|
||||||
{search_query}
|
|
||||||
|
|
||||||
- **Searched Web Pages:**
|
|
||||||
{document}
|
|
||||||
"""
|

@@ -16,21 +16,23 @@
 import logging
 import os
 import sys
+import time
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
-from quart import Blueprint, Quart, request, g, current_app, session
-from flasgger import Swagger
+from quart import Blueprint, Quart, request, g, current_app, session, jsonify
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from quart_cors import cors
-from common.constants import StatusEnum
+from common.constants import StatusEnum, RetCode
 from api.db.db_models import close_connection, APIToken
 from api.db.services import UserService
 from api.utils.json_encode import CustomJSONEncoder
 from api.utils import commands

-from quart_auth import Unauthorized
+from quart_auth import Unauthorized as QuartAuthUnauthorized
+from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
+from quart_schema import QuartSchema
 from common import settings
-from api.utils.api_utils import server_error_response
+from api.utils.api_utils import server_error_response, get_json_result
 from api.constants import API_VERSION
 from common.misc_utils import get_uuid

@@ -38,41 +40,27 @@ settings.init_settings()

 __all__ = ["app"]

+UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
+
+
+def _unauthorized_message(error):
+    if error is None:
+        return UNAUTHORIZED_MESSAGE
+    try:
+        msg = repr(error)
+    except Exception:
+        return UNAUTHORIZED_MESSAGE
+    if msg == UNAUTHORIZED_MESSAGE:
+        return msg
+    if "Unauthorized" in msg and "401" in msg:
+        return msg
+    return UNAUTHORIZED_MESSAGE
+
 app = Quart(__name__)
 app = cors(app, allow_origin="*")

-# Add this at the beginning of your file to configure Swagger UI
-swagger_config = {
-    "headers": [],
-    "specs": [
-        {
-            "endpoint": "apispec",
-            "route": "/apispec.json",
-            "rule_filter": lambda rule: True,  # Include all endpoints
-            "model_filter": lambda tag: True,  # Include all models
-        }
-    ],
-    "static_url_path": "/flasgger_static",
-    "swagger_ui": True,
-    "specs_route": "/apidocs/",
-}
-
-swagger = Swagger(
-    app,
-    config=swagger_config,
-    template={
-        "swagger": "2.0",
-        "info": {
-            "title": "RAGFlow API",
-            "description": "",
-            "version": "1.0.0",
-        },
-        "securityDefinitions": {
-            "ApiKeyAuth": {"type": "apiKey", "name": "Authorization", "in": "header"}
-        },
-    },
-)
+# openapi supported
+QuartSchema(app)

 app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
@@ -103,12 +91,13 @@ from werkzeug.local import LocalProxy
 T = TypeVar("T")
 P = ParamSpec("P")

+
 def _load_user():
     jwt = Serializer(secret_key=settings.SECRET_KEY)
     authorization = request.headers.get("Authorization")
     g.user = None
     if not authorization:
-        return
+        return None

     try:
         access_token = str(jwt.loads(authorization))
@@ -125,18 +114,28 @@ def _load_user():
         user = UserService.query(
             access_token=access_token, status=StatusEnum.VALID.value
         )
-        if not user and len(authorization.split()) == 2:
-            objs = APIToken.query(token=authorization.split()[1])
-            if objs:
-                user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
         if user:
             if not user[0].access_token or not user[0].access_token.strip():
                 logging.warning(f"User {user[0].email} has empty access_token in database")
                 return None
             g.user = user[0]
             return user[0]
-    except Exception as e:
-        logging.warning(f"load_user got exception {e}")
+    except Exception as e_auth:
+        logging.warning(f"load_user got exception {e_auth}")
+        try:
+            authorization = request.headers.get("Authorization")
+            if len(authorization.split()) == 2:
+                objs = APIToken.query(token=authorization.split()[1])
+                if objs:
+                    user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
+                    if user:
+                        if not user[0].access_token or not user[0].access_token.strip():
+                            logging.warning(f"User {user[0].email} has empty access_token in database")
+                            return None
+                        g.user = user[0]
+                        return user[0]
+        except Exception as e_api_token:
+            logging.warning(f"load_user got exception {e_api_token}")


 current_user = LocalProxy(_load_user)
@@ -164,10 +163,18 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:

     @wraps(func)
     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
-        if not current_user:# or not session.get("_user_id"):
-            raise Unauthorized()
-        else:
-            return await current_app.ensure_async(func)(*args, **kwargs)
+        timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+        t_start = time.perf_counter() if timing_enabled else None
+        user = current_user
+        if timing_enabled:
+            logging.info(
+                "api_timing login_required auth_ms=%.2f path=%s",
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        if not user:  # or not session.get("_user_id"):
+            raise QuartAuthUnauthorized()
+        return await current_app.ensure_async(func)(*args, **kwargs)

     return wrapper

@@ -228,6 +235,7 @@ def logout_user():

     return True

+
 def search_pages_path(page_path):
     app_path_list = [
         path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
@@ -274,6 +282,36 @@ client_urls_prefix = [
 ]


+@app.errorhandler(404)
+async def not_found(error):
+    logging.error(f"The requested URL {request.path} was not found")
+    message = f"Not Found: {request.path}"
+    response = {
+        "code": RetCode.NOT_FOUND,
+        "message": message,
+        "data": None,
+        "error": "Not Found",
+    }
+    return jsonify(response), RetCode.NOT_FOUND
+
+
+@app.errorhandler(401)
+async def unauthorized(error):
+    logging.warning("Unauthorized request")
+    return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
+
+
@app.errorhandler(QuartAuthUnauthorized)
|
||||||
|
async def unauthorized_quart_auth(error):
|
||||||
|
logging.warning("Unauthorized request (quart_auth)")
|
||||||
|
return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
@app.errorhandler(WerkzeugUnauthorized)
|
||||||
|
async def unauthorized_werkzeug(error):
|
||||||
|
logging.warning("Unauthorized request (werkzeug)")
|
||||||
|
return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
|
||||||
|
|
||||||
@app.teardown_request
|
@app.teardown_request
|
||||||
def _db_close(exception):
|
def _db_close(exception):
|
||||||
if exception:
|
if exception:
|
||||||
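The hunks above route missing pages and authentication failures through the project's JSON envelope (get_json_result / jsonify plus a RetCode status) instead of Quart's default HTML error pages. A minimal sketch of how the 401 path could be exercised with Quart's built-in test client follows; the route path and the exact envelope keys are assumptions for illustration, not part of this diff.

# Sketch only: assumes some @login_required route is mounted at /v1/user/info
# and that get_json_result() emits "code" and "message" keys.
import asyncio

async def check_unauthorized_envelope():
    client = app.test_client()
    resp = await client.get("/v1/user/info")  # deliberately no Authorization header
    body = await resp.get_json()
    assert resp.status_code == 401
    assert body["code"] == RetCode.UNAUTHORIZED
    assert "Unauthorized" in body["message"]

asyncio.run(check_unauthorized_envelope())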
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import inspect
 import json
 import logging
@@ -29,9 +28,14 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from common.constants import RetCode
-from common.misc_utils import get_uuid
-from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
-    get_request_json
+from common.misc_utils import get_uuid, thread_pool_exec
+from api.utils.api_utils import (
+    get_json_result,
+    server_error_response,
+    validate_request,
+    get_data_error_result,
+    get_request_json,
+)
 from agent.canvas import Canvas
 from peewee import MySQLDatabase, PostgresqlDatabase
 from api.db.db_models import APIToken, Task
@@ -132,12 +136,12 @@ async def run():
     files = req.get("files", [])
     inputs = req.get("inputs", {})
     user_id = req.get("user_id", current_user.id)
-    if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
+    if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
             code=RetCode.OPERATING_ERROR)
 
-    e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
+    e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
     if not e:
         return get_data_error_result(message="canvas not found.")
 
@@ -147,13 +151,13 @@ async def run():
     if cvs.canvas_category == CanvasCategory.DataFlow:
         task_id = get_uuid()
         Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
-        ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
+        ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
         if not ok:
             return get_data_error_result(message=error_message)
         return get_json_result(data={"message_id": task_id})
 
     try:
-        canvas = Canvas(cvs.dsl, current_user.id)
+        canvas = Canvas(cvs.dsl, current_user.id, canvas_id=cvs.id)
     except Exception as e:
         return server_error_response(e)
 
@@ -192,7 +196,7 @@ async def rerun():
         if 0 < doc["progress"] < 1:
             return get_data_error_result(message=f"`{doc['name']}` is processing...")
 
-        if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
+        if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
            settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
        doc["progress_msg"] = ""
        doc["chunk_num"] = 0
@@ -232,7 +236,7 @@ async def reset():
        if not e:
            return get_data_error_result(message="canvas not found.")
 
-        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
+        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
        canvas.reset()
        req["dsl"] = json.loads(str(canvas))
        UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
@@ -270,7 +274,7 @@ def input_form():
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)
 
-        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
+        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
        return get_json_result(data=canvas.get_component_input_form(cpn_id))
    except Exception as e:
        return server_error_response(e)
@@ -287,7 +291,7 @@ async def debug():
            code=RetCode.OPERATING_ERROR)
    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
-        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
+        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
        canvas.reset()
        canvas.message_id = get_uuid()
        component = canvas.get_component(req["component_id"])["obj"]
@@ -540,6 +544,7 @@ def sessions(canvas_id):
 @login_required
 def prompts():
     from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
 
     return get_json_result(data={
         "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
         "plan_generation": NEXT_STEP,
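The canvas endpoints above swap asyncio.to_thread for a shared thread_pool_exec helper imported from common.misc_utils. The helper's implementation is not part of this diff; a minimal sketch of what such a wrapper could look like, assuming a bounded ThreadPoolExecutor reused across calls, is:

# Hypothetical sketch; the real common.misc_utils.thread_pool_exec may differ.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

_EXECUTOR = ThreadPoolExecutor(max_workers=32)  # assumed pool size

async def thread_pool_exec(func, *args, **kwargs):
    # Run a blocking callable on a dedicated pool instead of the loop's default executor.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_EXECUTOR, partial(func, *args, **kwargs))

Under that assumption, awaiting thread_pool_exec(UserCanvasService.get_by_id, req["id"]) behaves like the old asyncio.to_thread call, except that all blocking DB work shares one sized pool.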
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
+import base64
 import datetime
 import json
+import logging
 import re
-import base64
 import xxhash
 from quart import request
 
@@ -27,8 +27,14 @@ from api.db.services.llm_service import LLMBundle
 from common.metadata_utils import apply_meta_data_filter
 from api.db.services.search_service import SearchService
 from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
-    get_request_json
+from api.utils.api_utils import (
+    get_data_error_result,
+    get_json_result,
+    server_error_response,
+    validate_request,
+    get_request_json,
+)
+from common.misc_utils import thread_pool_exec
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
@@ -38,7 +44,6 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
 from common import settings
 from api.apps import login_required, current_user
 
-
 @manager.route('/list', methods=['POST']) # noqa: F821
 @login_required
 @validate_request("doc_id")
@@ -61,7 +66,7 @@ async def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
+        sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -76,6 +81,7 @@ async def list_chunk():
                 "image_id": sres.field[id].get("img_id", ""),
                 "available_int": int(sres.field[id].get("available_int", 1)),
                 "positions": sres.field[id].get("position_int", []),
+                "doc_type_kwd": sres.field[id].get("doc_type_kwd")
             }
             assert isinstance(d["positions"], list)
             assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
@@ -125,10 +131,15 @@ def get():
 @validate_request("doc_id", "chunk_id", "content_with_weight")
 async def set():
     req = await get_request_json()
+    content_with_weight = req["content_with_weight"]
+    if not isinstance(content_with_weight, (str, bytes)):
+        raise TypeError("expected string or bytes-like object")
+    if isinstance(content_with_weight, bytes):
+        content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
     d = {
         "id": req["chunk_id"],
-        "content_with_weight": req["content_with_weight"]}
-    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
+        "content_with_weight": content_with_weight}
+    d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     if "important_kwd" in req:
         if not isinstance(req["important_kwd"], list):
@@ -170,19 +181,21 @@ async def set():
                 _d = beAdoc(d, q, a, not any(
                     [rag_tokenizer.is_chinese(t) for t in q + a]))
 
-            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
+            v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
             v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
             _d["q_%d_vec" % len(v)] = v.tolist()
             settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
 
             # update image
             image_base64 = req.get("image_base64", None)
-            if image_base64:
+            img_id = req.get("img_id", "")
+            if image_base64 and img_id and "-" in img_id:
+                bkt, name = img_id.split("-", 1)
                 image_binary = base64.b64decode(image_base64)
-                settings.STORAGE_IMPL.put(req["doc_id"], req["chunk_id"], image_binary)
+                settings.STORAGE_IMPL.put(bkt, name, image_binary)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_set_sync)
+        return await thread_pool_exec(_set_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -205,7 +218,7 @@ async def switch():
                 return get_data_error_result(message="Index updating failure")
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_switch_sync)
+        return await thread_pool_exec(_switch_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -220,19 +233,34 @@ async def rm():
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
                 return get_data_error_result(message="Document not found!")
-            if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
-                                                search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                                doc.kb_id):
+            condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
+            try:
+                deleted_count = settings.docStoreConn.delete(condition,
+                                                             search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                             doc.kb_id)
+            except Exception:
                 return get_data_error_result(message="Chunk deleting failure")
             deleted_chunk_ids = req["chunk_ids"]
-            chunk_number = len(deleted_chunk_ids)
+            if isinstance(deleted_chunk_ids, list):
+                unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
+                has_ids = len(unique_chunk_ids) > 0
+            else:
+                unique_chunk_ids = [deleted_chunk_ids]
+                has_ids = deleted_chunk_ids not in (None, "")
+            if has_ids and deleted_count == 0:
+                return get_data_error_result(message="Index updating failure")
+            if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
+                deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
+                                                              search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                              doc.kb_id)
+            chunk_number = deleted_count
             DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
             for cid in deleted_chunk_ids:
                 if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
                     settings.STORAGE_IMPL.rm(doc.kb_id, cid)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -242,6 +270,7 @@ async def rm():
 @validate_request("doc_id", "content_with_weight")
 async def create():
     req = await get_request_json()
+    req_id = request.headers.get("X-Request-ID")
     chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
     d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
          "content_with_weight": req["content_with_weight"]}
@@ -258,14 +287,23 @@ async def create():
     d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
     if "tag_feas" in req:
         d["tag_feas"] = req["tag_feas"]
-    if "tag_feas" in req:
-        d["tag_feas"] = req["tag_feas"]
 
     try:
+        def _log_response(resp, code, message):
+            logging.info(
+                "chunk_create response req_id=%s status=%s code=%s message=%s",
+                req_id,
+                getattr(resp, "status_code", None),
+                code,
+                message,
+            )
+
         def _create_sync():
             e, doc = DocumentService.get_by_id(req["doc_id"])
             if not e:
-                return get_data_error_result(message="Document not found!")
+                resp = get_data_error_result(message="Document not found!")
+                _log_response(resp, RetCode.DATA_ERROR, "Document not found!")
+                return resp
             d["kb_id"] = [doc.kb_id]
             d["docnm_kwd"] = doc.name
             d["title_tks"] = rag_tokenizer.tokenize(doc.name)
@@ -273,11 +311,15 @@ async def create():
 
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
-                return get_data_error_result(message="Tenant not found!")
+                resp = get_data_error_result(message="Tenant not found!")
+                _log_response(resp, RetCode.DATA_ERROR, "Tenant not found!")
+                return resp
 
             e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
             if not e:
-                return get_data_error_result(message="Knowledgebase not found!")
+                resp = get_data_error_result(message="Knowledgebase not found!")
+                _log_response(resp, RetCode.DATA_ERROR, "Knowledgebase not found!")
+                return resp
             if kb.pagerank:
                 d[PAGERANK_FLD] = kb.pagerank
 
@@ -291,10 +333,13 @@ async def create():
 
             DocumentService.increment_chunk_num(
                 doc.id, doc.kb_id, c, 1, 0)
-            return get_json_result(data={"chunk_id": chunck_id})
+            resp = get_json_result(data={"chunk_id": chunck_id})
+            _log_response(resp, RetCode.SUCCESS, "success")
+            return resp
 
-        return await asyncio.to_thread(_create_sync)
+        return await thread_pool_exec(_create_sync)
     except Exception as e:
+        logging.info("chunk_create exception req_id=%s error=%r", req_id, e)
         return server_error_response(e)
 
 
@@ -370,16 +415,23 @@ async def retrieval_test():
             _question += await keyword_extraction(chat_mdl, _question)
 
         labels = label_question(_question, [kb])
-        ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
-                float(req.get("similarity_threshold", 0.0)),
-                float(req.get("vector_similarity_weight", 0.3)),
-                top,
-                local_doc_ids, rerank_mdl=rerank_mdl,
-                highlight=req.get("highlight", False),
-                rank_feature=labels
-                )
+        ranks = await settings.retriever.retrieval(
+            _question,
+            embd_mdl,
+            tenant_ids,
+            kb_ids,
+            page,
+            size,
+            float(req.get("similarity_threshold", 0.0)),
+            float(req.get("vector_similarity_weight", 0.3)),
+            doc_ids=local_doc_ids,
+            top=top,
+            rerank_mdl=rerank_mdl,
+            rank_feature=labels
+        )
 
         if use_kg:
-            ck = settings.kg_retriever.retrieval(_question,
+            ck = await settings.kg_retriever.retrieval(_question,
                 tenant_ids,
                 kb_ids,
                 embd_mdl,
@@ -405,7 +457,7 @@ async def retrieval_test():
 
 @manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
 @login_required
-def knowledge_graph():
+async def knowledge_graph():
     doc_id = request.args["doc_id"]
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
@@ -413,7 +465,7 @@ def knowledge_graph():
         "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+    sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
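In the chunk update hunk above, the stored image is now written to the bucket and object name decoded from the chunk's img_id rather than to (doc_id, chunk_id). Because split("-", 1) keeps everything after the first dash as the object name, object names that themselves contain dashes survive the round trip:

# Example of the split used above; the img_id value is made up for illustration.
img_id = "kb0001-chunk42-thumbnail.png"
bkt, name = img_id.split("-", 1)
# bkt  == "kb0001"
# name == "chunk42-thumbnail.png"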
@@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ
 from common.misc_utils import get_uuid
 from common.constants import RetCode
 from api.apps import login_required, current_user
+import logging
 
 
 @manager.route('/set', methods=['POST']) # noqa: F821
@@ -42,13 +43,19 @@ async def set_dialog():
     if len(name.encode("utf-8")) > 255:
         return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
 
-    if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
-        name = name.strip()
-        name = duplicate_name(
-            DialogService.query,
-            name=name,
-            tenant_id=current_user.id,
-            status=StatusEnum.VALID.value)
+    name = name.strip()
+    if is_create:
+        # only for chat creating
+        existing_names = {
+            d.name.casefold()
+            for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
+            if d.name
+        }
+        if name.casefold() in existing_names:
+            def _name_exists(name: str, **_kwargs) -> bool:
+                return name.casefold() in existing_names
+
+            name = duplicate_name(_name_exists, name=name)
 
     description = req.get("description", "A helpful dialog")
     icon = req.get("icon", "")
@@ -63,16 +70,30 @@ async def set_dialog():
     meta_data_filter = req.get("meta_data_filter", {})
     prompt_config = req["prompt_config"]
 
+    # Set default parameters for datasets with knowledge retrieval
+    # All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
+    kb_ids = req.get("kb_ids", [])
+    parameters = prompt_config.get("parameters")
+    logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
+    # Check if parameters is missing, None, or empty list
+    if kb_ids and not parameters:
+        # Check if system prompt uses {knowledge} placeholder
+        if "{knowledge}" in prompt_config.get("system", ""):
+            # Set default parameters for any dataset with knowledge placeholder
+            prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
+            logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")
 
     if not is_create:
-        if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
+        # only for chat updating
+        if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
             return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
 
-    for p in prompt_config["parameters"]:
+    for p in prompt_config.get("parameters", []):
         if p["optional"]:
             continue
-        if prompt_config["system"].find("{%s}" % p["key"]) < 0:
+        if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
             return get_data_error_result(
                 message="Parameter '{}' is not used".format(p["key"]))
 
     try:
         e, tenant = TenantService.get_by_id(current_user.id)
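The chat-creation hunk above pre-fetches the tenant's dialog names, compares them case-insensitively via casefold(), and only then asks duplicate_name for a free variant through a small predicate. duplicate_name lives in api.db.services and its exact suffixing scheme is not shown in this diff; a hedged stand-in for the behaviour the new call relies on:

# Illustrative stand-in, not the project's duplicate_name implementation.
def duplicate_name(exists, name):
    # Append a counter until the predicate reports the candidate as unused.
    candidate, n = name, 1
    while exists(name=candidate):
        candidate = f"{name}({n})"
        n += 1
    return candidate

existing_names = {"support bot"}
print(duplicate_name(lambda name, **_: name.casefold() in existing_names, "Support Bot"))
# -> "Support Bot(1)" under this assumed scheme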
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-import asyncio
 import json
 import os.path
 import pathlib
@@ -27,18 +26,19 @@ from api.db import VALID_FILE_TYPES, FileType
 from api.db.db_models import Task
 from api.db.services import duplicate_name
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from common.metadata_utils import meta_filter, convert_conditions
+from common.metadata_utils import meta_filter, convert_conditions, turn2jsonschema
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import TaskService, cancel_all_task_of
 from api.db.services.user_service import UserTenantService
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.utils.api_utils import (
     get_data_error_result,
     get_json_result,
     server_error_response,
-    validate_request, get_request_json,
+    validate_request,
+    get_request_json,
 )
 from api.utils.file_utils import filename_type, thumbnail
 from common.file_utils import get_project_base_directory
@@ -62,10 +62,21 @@ async def upload():
         return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
 
     file_objs = files.getlist("file")
+    def _close_file_objs(objs):
+        for obj in objs:
+            try:
+                obj.close()
+            except Exception:
+                try:
+                    obj.stream.close()
+                except Exception:
+                    pass
     for file_obj in file_objs:
         if file_obj.filename == "":
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
         if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
 
     e, kb = KnowledgebaseService.get_by_id(kb_id)
@@ -74,8 +85,9 @@ async def upload():
     if not check_kb_team_permission(kb, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
+    err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
     if err:
+        files = [f[0] for f in files] if files else []
         return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
 
     if not files:
@@ -214,6 +226,7 @@ async def list_docs():
     kb_id = request.args.get("kb_id")
     if not kb_id:
         return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
 
     tenants = UserTenantService.query(user_id=current_user.id)
     for tenant in tenants:
         if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
@@ -234,6 +247,10 @@ async def list_docs():
 
     req = await get_request_json()
 
+    return_empty_metadata = req.get("return_empty_metadata", False)
+    if isinstance(return_empty_metadata, str):
+        return_empty_metadata = return_empty_metadata.lower() == "true"
 
     run_status = req.get("run_status", [])
     if run_status:
         invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
@@ -248,18 +265,73 @@ async def list_docs():
 
     suffix = req.get("suffix", [])
     metadata_condition = req.get("metadata_condition", {}) or {}
-    if metadata_condition and not isinstance(metadata_condition, dict):
-        return get_data_error_result(message="metadata_condition must be an object.")
+    metadata = req.get("metadata", {}) or {}
+    if isinstance(metadata, dict) and metadata.get("empty_metadata"):
+        return_empty_metadata = True
+        metadata = {k: v for k, v in metadata.items() if k != "empty_metadata"}
+    if return_empty_metadata:
+        metadata_condition = {}
+        metadata = {}
+    else:
+        if metadata_condition and not isinstance(metadata_condition, dict):
+            return get_data_error_result(message="metadata_condition must be an object.")
+        if metadata and not isinstance(metadata, dict):
+            return get_data_error_result(message="metadata must be an object.")
 
     doc_ids_filter = None
-    if metadata_condition:
+    metas = None
+    if metadata_condition or metadata:
         metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
-        doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+        if metadata_condition:
+            doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
             if metadata_condition.get("conditions") and not doc_ids_filter:
                 return get_json_result(data={"total": 0, "docs": []})
 
+        if metadata:
+            metadata_doc_ids = None
+            for key, values in metadata.items():
+                if not values:
+                    continue
+                if not isinstance(values, list):
+                    values = [values]
+                values = [str(v) for v in values if v is not None and str(v).strip()]
+                if not values:
+                    continue
+                key_doc_ids = set()
+                for value in values:
+                    key_doc_ids.update(metas.get(key, {}).get(value, []))
+                if metadata_doc_ids is None:
+                    metadata_doc_ids = key_doc_ids
+                else:
+                    metadata_doc_ids &= key_doc_ids
+                if not metadata_doc_ids:
+                    return get_json_result(data={"total": 0, "docs": []})
+            if metadata_doc_ids is not None:
+                if doc_ids_filter is None:
+                    doc_ids_filter = metadata_doc_ids
+                else:
+                    doc_ids_filter &= metadata_doc_ids
+                if not doc_ids_filter:
+                    return get_json_result(data={"total": 0, "docs": []})
 
+    if doc_ids_filter is not None:
+        doc_ids_filter = list(doc_ids_filter)
 
     try:
-        docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids_filter)
+        docs, tol = DocumentService.get_by_kb_id(
+            kb_id,
+            page_number,
+            items_per_page,
+            orderby,
+            desc,
+            keywords,
+            run_status,
+            types,
+            suffix,
+            doc_ids_filter,
+            return_empty_metadata=return_empty_metadata,
+        )
 
         if create_time_from or create_time_to:
             filtered_docs = []
@@ -274,6 +346,8 @@ async def list_docs():
                 doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
             if doc_item.get("source_type"):
                 doc_item["source_type"] = doc_item["source_type"].split("/")[0]
+            if doc_item["parser_config"].get("metadata"):
+                doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
 
         return get_json_result(data={"total": tol, "docs": docs})
     except Exception as e:
@@ -335,6 +409,7 @@ async def doc_infos():
 async def metadata_summary():
     req = await get_request_json()
     kb_id = req.get("kb_id")
+    doc_ids = req.get("doc_ids")
     if not kb_id:
         return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
 
@@ -346,7 +421,7 @@ async def metadata_summary():
         return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
 
     try:
-        summary = DocumentService.get_metadata_summary(kb_id)
+        summary = DocumentService.get_metadata_summary(kb_id, doc_ids)
         return get_json_result(data={"summary": summary})
     except Exception as e:
         return server_error_response(e)
@@ -354,36 +429,16 @@ async def metadata_summary():
 
 @manager.route("/metadata/update", methods=["POST"]) # noqa: F821
 @login_required
+@validate_request("doc_ids")
 async def metadata_update():
     req = await get_request_json()
-    kb_id = req.get("kb_id")
-    if not kb_id:
-        return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
-
-    tenants = UserTenantService.query(user_id=current_user.id)
-    for tenant in tenants:
-        if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
-            break
-    else:
-        return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
-
-    selector = req.get("selector", {}) or {}
+    document_ids = req.get("doc_ids")
     updates = req.get("updates", []) or []
     deletes = req.get("deletes", []) or []
 
-    if not isinstance(selector, dict):
-        return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR)
     if not isinstance(updates, list) or not isinstance(deletes, list):
         return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
 
-    metadata_condition = selector.get("metadata_condition", {}) or {}
-    if metadata_condition and not isinstance(metadata_condition, dict):
-        return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR)
-
-    document_ids = selector.get("document_ids", []) or []
-    if document_ids and not isinstance(document_ids, list):
-        return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR)
-
     for upd in updates:
         if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
             return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
@@ -391,24 +446,28 @@ async def metadata_update():
         if not isinstance(d, dict) or not d.get("key"):
             return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
 
-    kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id])
-    target_doc_ids = set(kb_doc_ids)
-    if document_ids:
-        invalid_ids = set(document_ids) - set(kb_doc_ids)
-        if invalid_ids:
-            return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR)
-        target_doc_ids = set(document_ids)
-
-    if metadata_condition:
-        metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
-        filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
-        target_doc_ids = target_doc_ids & filtered_ids
-        if metadata_condition.get("conditions") and not target_doc_ids:
-            return get_json_result(data={"updated": 0, "matched_docs": 0})
-
-    target_doc_ids = list(target_doc_ids)
-    updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes)
-    return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
+    updated = DocumentService.batch_update_metadata(None, document_ids, updates, deletes)
+    return get_json_result(data={"updated": updated})
+
+
+@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
+@login_required
+@validate_request("doc_id", "metadata")
+async def update_metadata_setting():
+    req = await get_request_json()
+    if not DocumentService.accessible(req["doc_id"], current_user.id):
+        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+
+    e, doc = DocumentService.get_by_id(req["doc_id"])
+    if not e:
+        return get_data_error_result(message="Document not found!")
+
+    DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
+    e, doc = DocumentService.get_by_id(doc.id)
+    if not e:
+        return get_data_error_result(message="Document not found!")
+
+    return get_json_result(data=doc.to_dict())
 
 
 @manager.route("/thumbnails", methods=["GET"]) # noqa: F821
@@ -442,31 +501,61 @@ async def change_status():
         return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=RetCode.ARGUMENT_ERROR)
 
     result = {}
+    has_error = False
     for doc_id in doc_ids:
         if not DocumentService.accessible(doc_id, current_user.id):
             result[doc_id] = {"error": "No authorization."}
+            has_error = True
            continue
 
        try:
            e, doc = DocumentService.get_by_id(doc_id)
            if not e:
                result[doc_id] = {"error": "No authorization."}
+                has_error = True
                continue
            e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
            if not e:
                result[doc_id] = {"error": "Can't find this dataset!"}
+                has_error = True
+                continue
+            current_status = str(doc.status)
+            if current_status == status:
+                result[doc_id] = {"status": status}
                continue
            if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
                result[doc_id] = {"error": "Database error (Document update)!"}
+                has_error = True
                continue
 
            status_int = int(status)
-            if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
-                result[doc_id] = {"error": "Database error (docStore update)!"}
+            if getattr(doc, "chunk_num", 0) > 0:
+                try:
+                    ok = settings.docStoreConn.update(
+                        {"doc_id": doc_id},
+                        {"available_int": status_int},
+                        search.index_name(kb.tenant_id),
+                        doc.kb_id,
+                    )
+                except Exception as exc:
+                    msg = str(exc)
+                    if "3022" in msg:
+                        result[doc_id] = {"error": "Document store table missing."}
+                    else:
+                        result[doc_id] = {"error": f"Document store update failed: {msg}"}
+                    has_error = True
+                    continue
+                if not ok:
+                    result[doc_id] = {"error": "Database error (docStore update)!"}
+                    has_error = True
+                    continue
            result[doc_id] = {"status": status}
        except Exception as e:
            result[doc_id] = {"error": f"Internal server error: {str(e)}"}
+            has_error = True
 
+    if has_error:
+        return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR)
     return get_json_result(data=result)
 
 
@@ -483,7 +572,7 @@ async def rm():
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
+    errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
 
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@@ -496,10 +585,11 @@ async def rm():
 @validate_request("doc_ids", "run")
 async def run():
     req = await get_request_json()
+    uid = current_user.id
     try:
         def _run_sync():
             for doc_id in req["doc_ids"]:
-                if not DocumentService.accessible(doc_id, current_user.id):
+                if not DocumentService.accessible(doc_id, uid):
                     return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
             kb_table_num_map = {}
@@ -528,16 +618,24 @@ async def run():
                 DocumentService.update_by_id(id, info)
                 if req.get("delete", False):
                     TaskService.filter_delete([Task.doc_id == id])
-                    if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                    if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
                         settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
 
                 if str(req["run"]) == TaskStatus.RUNNING.value:
+                    if req.get("apply_kb"):
+                        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
+                        if not e:
+                            raise LookupError("Can't find this dataset!")
+                        doc.parser_config["llm_id"] = kb.parser_config.get("llm_id")
+                        doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False)
+                        doc.parser_config["metadata"] = kb.parser_config.get("metadata", {})
+                        DocumentService.update_parser_config(doc.id, doc.parser_config)
                     doc_dict = doc.to_dict()
                     DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
 
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_run_sync)
+        return await thread_pool_exec(_run_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -547,9 +645,10 @@ async def run():
 @validate_request("doc_id", "name")
 async def rename():
     req = await get_request_json()
+    uid = current_user.id
     try:
         def _rename_sync():
-            if not DocumentService.accessible(req["doc_id"], current_user.id):
+            if not DocumentService.accessible(req["doc_id"], uid):
                 return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
             e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -579,7 +678,7 @@ async def rename():
                     "title_tks": title_tks,
                     "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
                 }
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
                     settings.docStoreConn.update(
                         {"doc_id": req["doc_id"]},
                         es_body,
@@ -588,7 +687,7 @@ async def rename():
                     )
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rename_sync)
+        return await thread_pool_exec(_rename_sync)
 
     except Exception as e:
         return server_error_response(e)
@@ -603,7 +702,7 @@ async def get(doc_id):
         return get_data_error_result(message="Document not found!")
 
     b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-    data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+    data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
     response = await make_response(data)
 
     ext = re.search(r"\.([^.]+)$", doc.name.lower())
@@ -625,7 +724,7 @@ async def get(doc_id):
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
 
@@ -660,7 +759,8 @@ async def change_parser():
     tenant_id = DocumentService.get_tenant_id(req["doc_id"])
     if not tenant_id:
         return get_data_error_result(message="Tenant not found!")
-    if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+    DocumentService.delete_chunk_images(doc, tenant_id)
+    if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
|
||||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -697,7 +797,7 @@ async def get_image(image_id):
|
|||||||
if len(arr) != 2:
|
if len(arr) != 2:
|
||||||
return get_data_error_result(message="Image not found.")
|
return get_data_error_result(message="Image not found.")
|
||||||
bkt, nm = image_id.split("-")
|
bkt, nm = image_id.split("-")
|
||||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
|
||||||
response = await make_response(data)
|
response = await make_response(data)
|
||||||
response.headers.set("Content-Type", "image/JPEG")
|
response.headers.set("Content-Type", "image/JPEG")
|
||||||
return response
|
return response
|
||||||
|
|||||||
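The hunks above (and the ones in the following files) consistently replace `asyncio.to_thread(...)` with a `thread_pool_exec(...)` helper imported from `common.misc_utils`. The helper itself is not part of this diff; the block below is only a minimal sketch, assuming it wraps a shared `ThreadPoolExecutor` behind `loop.run_in_executor`, which is the usual shape of such a utility. Names and the pool size are assumptions, not the actual RAGFlow implementation.

```python
# Hypothetical sketch of a thread_pool_exec helper; the real common.misc_utils
# implementation is not shown in this diff and may differ.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

_EXECUTOR = ThreadPoolExecutor(max_workers=32)  # assumed pool size


async def thread_pool_exec(func, *args, **kwargs):
    # Run a blocking callable on a shared thread pool without blocking the event loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_EXECUTOR, partial(func, *args, **kwargs))
```

Call sites then read exactly as in the hunks above, for example `await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)`; compared with `asyncio.to_thread`, a module-level executor lets the pool size be tuned independently of the loop's default executor.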
@@ -14,7 +14,6 @@
# limitations under the License
#
import logging
- import asyncio
import os
import pathlib
import re
@@ -25,7 +24,7 @@ from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
- from common.misc_utils import get_uuid
+ from common.misc_utils import get_uuid, thread_pool_exec
from common.constants import RetCode, FileSource
from api.db import FileType
from api.db.services import duplicate_name
@@ -35,7 +34,6 @@ from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings


@manager.route('/upload', methods=['POST'])  # noqa: F821
@login_required
# @validate_request("parent_id")
@@ -65,7 +63,7 @@ async def upload():

async def _handle_single_file(file_obj):
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
- if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
+ if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
return get_data_error_result( message="Exceed the maximum file number of a free user!")

# split file name path
@@ -77,35 +75,35 @@ async def upload():
file_len = len(file_obj_names)

# get folder
- file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
+ file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)

# create folder
if file_len != len_id_list:
- e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
+ e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
if not e:
return get_data_error_result(message="Folder not found!")
- last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
+ last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
- e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
+ e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
if not e:
return get_data_error_result(message="Folder not found!")
- last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
+ last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)

# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
- while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
+ while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
location += "_"
- blob = await asyncio.to_thread(file_obj.read)
+ blob = await thread_pool_exec(file_obj.read)
- filename = await asyncio.to_thread(
+ filename = await thread_pool_exec(
duplicate_name,
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
- await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+ await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
file_data = {
"id": get_uuid(),
"parent_id": last_folder.id,
@@ -116,7 +114,7 @@ async def upload():
"location": location,
"size": len(blob),
}
- inserted = await asyncio.to_thread(FileService.insert, file_data)
+ inserted = await thread_pool_exec(FileService.insert, file_data)
return inserted.to_json()

for file_obj in file_objs:
@@ -249,6 +247,7 @@ def get_all_parent_folders():
async def rm():
req = await get_request_json()
file_ids = req["file_ids"]
+ uid = current_user.id

try:
def _delete_single_file(file):
@@ -287,21 +286,21 @@ async def rm():
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
- if not check_file_team_permission(file, current_user.id):
+ if not check_file_team_permission(file, uid):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

if file.source_type == FileSource.KNOWLEDGEBASE:
continue

if file.type == FileType.FOLDER.value:
- _delete_folder_recursive(file, current_user.id)
+ _delete_folder_recursive(file, uid)
continue

_delete_single_file(file)

return get_json_result(data=True)

- return await asyncio.to_thread(_rm_sync)
+ return await thread_pool_exec(_rm_sync)

except Exception as e:
return server_error_response(e)
@@ -357,10 +356,10 @@ async def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)

- blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
+ blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
- blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+ blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)

response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -460,7 +459,7 @@ async def move():
_move_entry_recursive(file, dest_folder)
return get_json_result(data=True)

- return await asyncio.to_thread(_move_sync)
+ return await thread_pool_exec(_move_sync)

except Exception as e:
return server_error_response(e)

@@ -17,8 +17,8 @@ import json
import logging
import random
import re
- import asyncio

+ from common.metadata_utils import turn2jsonschema
from quart import request
import numpy as np

@@ -30,8 +30,15 @@ from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
- from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
-     get_request_json
+ from api.utils.api_utils import (
+     get_error_data_result,
+     server_error_response,
+     get_data_error_result,
+     validate_request,
+     not_allowed_parameters,
+     get_request_json,
+ )
+ from common.misc_utils import thread_pool_exec
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@@ -39,12 +46,11 @@ from api.utils.api_utils import get_json_result
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.utils.redis_conn import REDIS_CONN
- from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
+ from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user


@manager.route('/create', methods=['post'])  # noqa: F821
@login_required
@validate_request("name")
@@ -82,6 +88,20 @@ async def update():
return get_data_error_result(
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
+ if settings.DOC_ENGINE_INFINITY:
+ parser_id = req.get("parser_id")
+ if isinstance(parser_id, str) and parser_id.lower() == "tag":
+ return get_json_result(
+ code=RetCode.OPERATING_ERROR,
+ message="The chunking method Tag has not been supported by Infinity yet.",
+ data=False,
+ )
+ if "pagerank" in req and req["pagerank"] > 0:
+ return get_json_result(
+ code=RetCode.DATA_ERROR,
+ message="'pagerank' can only be set when doc_engine is elasticsearch",
+ data=False,
+ )

if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
@@ -97,6 +117,19 @@ async def update():
code=RetCode.OPERATING_ERROR)

e, kb = KnowledgebaseService.get_by_id(req["kb_id"])

+ # Rename folder in FileService
+ if e and req["name"].lower() != kb.name.lower():
+ FileService.filter_update(
+ [
+ File.tenant_id == kb.tenant_id,
+ File.source_type == FileSource.KNOWLEDGEBASE,
+ File.type == "folder",
+ File.name == kb.name,
+ ],
+ {"name": req["name"]},
+ )
+
if not e:
return get_data_error_result(
message="Can't find this dataset!")
@@ -117,7 +150,7 @@ async def update():

if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
- await asyncio.to_thread(
+ await thread_pool_exec(
settings.docStoreConn.update,
{"kb_id": kb.id},
{PAGERANK_FLD: req["pagerank"]},
@@ -126,7 +159,7 @@ async def update():
)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
- await asyncio.to_thread(
+ await thread_pool_exec(
settings.docStoreConn.update,
{"exists": PAGERANK_FLD},
{"remove": PAGERANK_FLD},
@@ -150,6 +183,22 @@ async def update():
return server_error_response(e)


+ @manager.route('/update_metadata_setting', methods=['post'])  # noqa: F821
+ @login_required
+ @validate_request("kb_id", "metadata")
+ async def update_metadata_setting():
+ req = await get_request_json()
+ e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
+ if not e:
+ return get_data_error_result(
+ message="Database error (Knowledgebase rename)!")
+ kb = kb.to_dict()
+ kb["parser_config"]["metadata"] = req["metadata"]
+ kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True)
+ KnowledgebaseService.update_by_id(kb["id"], kb)
+ return get_json_result(data=kb)
+
+
@manager.route('/detail', methods=['GET'])  # noqa: F821
@login_required
def detail():
@@ -170,6 +219,8 @@ def detail():
message="Can't find this dataset!")
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
kb["connectors"] = Connector2KbService.list_connectors(kb_id)
+ if kb["parser_config"].get("metadata"):
+ kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"])

for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
if finish_at := kb.get(key):
@@ -221,7 +272,8 @@ async def list_kbs():
@validate_request("kb_id")
async def rm():
req = await get_request_json()
- if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
+ uid = current_user.id
+ if not KnowledgebaseService.accessible4deletion(req["kb_id"], uid):
return get_json_result(
data=False,
message='No authorization.',
@@ -229,7 +281,7 @@ async def rm():
)
try:
kbs = KnowledgebaseService.query(
- created_by=current_user.id, id=req["kb_id"])
+ created_by=uid, id=req["kb_id"])
if not kbs:
return get_json_result(
data=False, message='Only owner of dataset authorized for this operation.',
@@ -245,18 +297,31 @@ async def rm():
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
- [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
+ [
+ File.tenant_id == kbs[0].tenant_id,
+ File.source_type == FileSource.KNOWLEDGEBASE,
+ File.type == "folder",
+ File.name == kbs[0].name,
+ ]
+ )
+ # Delete the table BEFORE deleting the database record
+ for kb in kbs:
+ try:
+ settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+ settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
+ logging.info(f"Dropped index for dataset {kb.id}")
+ except Exception as e:
+ logging.error(f"Failed to drop index for dataset {kb.id}: {e}")

if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
- settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
- settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)

- return await asyncio.to_thread(_rm_sync)
+ return await thread_pool_exec(_rm_sync)
except Exception as e:
return server_error_response(e)

@@ -338,7 +403,7 @@ async def rename_tags(kb_id):

@manager.route('/<kb_id>/knowledge_graph', methods=['GET'])  # noqa: F821
@login_required
- def knowledge_graph(kb_id):
+ async def knowledge_graph(kb_id):
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -352,9 +417,9 @@ def knowledge_graph(kb_id):
}

obj = {"graph": {}, "mind_map": {}}
- if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
+ if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
- sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+ sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)

@@ -824,11 +889,11 @@ async def check_embedding():
index_nm = search.index_name(tenant_id)

res0 = docStoreConn.search(
- selectFields=[], highlightFields=[],
+ select_fields=[], highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
- matchExprs=[], orderBy=OrderByExpr(),
+ match_expressions=[], order_by=OrderByExpr(),
offset=0, limit=1,
- indexNames=index_nm, knowledgebaseIds=[kb_id]
+ index_names=index_nm, knowledgebase_ids=[kb_id]
)
total = docStoreConn.get_total(res0)
if total <= 0:
@@ -840,14 +905,14 @@ async def check_embedding():

for off in offsets:
res1 = docStoreConn.search(
- selectFields=list(base_fields),
+ select_fields=list(base_fields),
- highlightFields=[],
+ highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
- matchExprs=[], orderBy=OrderByExpr(),
+ match_expressions=[], order_by=OrderByExpr(),
offset=off, limit=1,
- indexNames=index_nm, knowledgebaseIds=[kb_id]
+ index_names=index_nm, knowledgebase_ids=[kb_id]
)
- ids = docStoreConn.get_chunk_ids(res1)
+ ids = docStoreConn.get_doc_ids(res1)
if not ids:
continue

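The hunks above add an `/update_metadata_setting` route that takes `kb_id` and `metadata`, plus an optional `enable_metadata` flag. A minimal client call might look like the sketch below; the host, port, URL prefix and authorization header are placeholders, and the exact metadata shape accepted before `turn2jsonschema` conversion is not shown in this diff.

```python
# Hypothetical client call against the new endpoint; base URL and auth are assumptions.
import requests

resp = requests.post(
    "http://localhost:9380/v1/kb/update_metadata_setting",  # assumed API prefix
    headers={"Authorization": "Bearer <session-token>"},     # placeholder credentials
    json={
        "kb_id": "<dataset-id>",
        "metadata": {"author": {"type": "string"}},          # illustrative metadata definition
        "enable_metadata": True,
    },
    timeout=30,
)
print(resp.json())
```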
@@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from rag.utils.base64_image import test_image
- from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
+ from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel


@manager.route("/factories", methods=["GET"])  # noqa: F821
@@ -157,7 +157,7 @@ async def add_llm():
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
# Assemble bedrock_ak, bedrock_sk, bedrock_region
- api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
+ api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])

elif factory == "LocalAI":
llm_name += "___LocalAI"
@@ -195,6 +195,9 @@ async def add_llm():
elif factory == "MinerU":
api_key = apikey_json(["api_key", "provider_order"])

+ elif factory == "PaddleOCR":
+ api_key = apikey_json(["api_key", "provider_order"])
+
llm = {
"tenant_id": current_user.id,
"llm_factory": factory,
@@ -208,70 +211,82 @@ async def add_llm():
msg = ""
|
msg = ""
|
||||||
mdl_nm = llm["llm_name"].split("___")[0]
|
mdl_nm = llm["llm_name"].split("___")[0]
|
||||||
extra = {"provider": factory}
|
extra = {"provider": factory}
|
||||||
if llm["model_type"] == LLMType.EMBEDDING.value:
|
model_type = llm["model_type"]
|
||||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
model_api_key = llm["api_key"]
|
||||||
mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
model_base_url = llm.get("api_base", "")
|
||||||
try:
|
match model_type:
|
||||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
case LLMType.EMBEDDING.value:
|
||||||
if len(arr[0]) == 0:
|
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||||
raise Exception("Fail")
|
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
except Exception as e:
|
try:
|
||||||
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
|
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||||
elif llm["model_type"] == LLMType.CHAT.value:
|
if len(arr[0]) == 0:
|
||||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
raise Exception("Fail")
|
||||||
mdl = ChatModel[factory](
|
except Exception as e:
|
||||||
key=llm["api_key"],
|
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
|
||||||
model_name=mdl_nm,
|
case LLMType.CHAT.value:
|
||||||
base_url=llm["api_base"],
|
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||||
**extra,
|
mdl = ChatModel[factory](
|
||||||
)
|
key=model_api_key,
|
||||||
try:
|
model_name=mdl_nm,
|
||||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
base_url=model_base_url,
|
||||||
if not tc and m.find("**ERROR**:") >= 0:
|
**extra,
|
||||||
raise Exception(m)
|
)
|
||||||
except Exception as e:
|
try:
|
||||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||||
elif llm["model_type"] == LLMType.RERANK:
|
if not tc and m.find("**ERROR**:") >= 0:
|
||||||
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
|
raise Exception(m)
|
||||||
try:
|
except Exception as e:
|
||||||
mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
|
|
||||||
if len(arr) == 0:
|
case LLMType.RERANK.value:
|
||||||
raise Exception("Not known.")
|
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
|
||||||
except KeyError:
|
try:
|
||||||
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
|
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
except Exception as e:
|
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
|
||||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
if len(arr) == 0:
|
||||||
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
|
raise Exception("Not known.")
|
||||||
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
|
except KeyError:
|
||||||
mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
|
||||||
try:
|
except Exception as e:
|
||||||
image_data = test_image
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
m, tc = mdl.describe(image_data)
|
|
||||||
if not tc and m.find("**ERROR**:") >= 0:
|
case LLMType.IMAGE2TEXT.value:
|
||||||
raise Exception(m)
|
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
|
||||||
except Exception as e:
|
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
try:
|
||||||
elif llm["model_type"] == LLMType.TTS:
|
image_data = test_image
|
||||||
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
|
m, tc = mdl.describe(image_data)
|
||||||
mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
|
if not tc and m.find("**ERROR**:") >= 0:
|
||||||
try:
|
raise Exception(m)
|
||||||
for resp in mdl.tts("Hello~ RAGFlower!"):
|
except Exception as e:
|
||||||
pass
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
except RuntimeError as e:
|
case LLMType.TTS.value:
|
||||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
|
||||||
elif llm["model_type"] == LLMType.OCR.value:
|
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
|
try:
|
||||||
try:
|
for resp in mdl.tts("Hello~ RAGFlower!"):
|
||||||
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
|
pass
|
||||||
ok, reason = mdl.check_available()
|
except RuntimeError as e:
|
||||||
if not ok:
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
raise RuntimeError(reason or "Model not available")
|
case LLMType.OCR.value:
|
||||||
except Exception as e:
|
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
|
||||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
try:
|
||||||
else:
|
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
# TODO: check other type of models
|
ok, reason = mdl.check_available()
|
||||||
pass
|
if not ok:
|
||||||
|
raise RuntimeError(reason or "Model not available")
|
||||||
|
except Exception as e:
|
||||||
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
|
case LLMType.SPEECH2TEXT:
|
||||||
|
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
|
||||||
|
try:
|
||||||
|
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
|
# TODO: check the availability
|
||||||
|
except Exception as e:
|
||||||
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
|
case _:
|
||||||
|
raise RuntimeError(f"Unknown model type: {model_type}")
|
||||||
|
|
||||||
if msg:
return get_data_error_result(message=msg)
@@ -358,17 +373,18 @@ def my_llms():

@manager.route("/list", methods=["GET"])  # noqa: F821
@login_required
- def list_app():
+ async def list_app():
self_deployed = ["FastEmbed", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
weighted = []
model_type = request.args.get("model_type")
+ tenant_id = current_user.id
try:
- TenantLLMService.ensure_mineru_from_env(current_user.id)
+ TenantLLMService.ensure_mineru_from_env(tenant_id)
- objs = TenantLLMService.query(tenant_id=current_user.id)
+ objs = TenantLLMService.query(tenant_id=tenant_id)
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
llms = LLMService.get_all()
- llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == 'Builtin' or (m.llm_name + "@" + m.fid) in status)]
+ llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == "Builtin" or (m.llm_name + "@" + m.fid) in status)]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):

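The large hunk above rewrites the per-type connectivity checks in `add_llm()` from an `if/elif` chain into a `match` statement keyed on `model_type`, with a trailing `case _:` that rejects unknown types. The block below is only a stripped-down sketch of that dispatch shape with placeholder branch bodies; it is not the actual handler logic, which constructs and probes the corresponding model classes.

```python
# Minimal sketch of a match/case dispatch over a model type, with a catch-all
# that raises on unknown values; branch bodies here are placeholders.
def check_model(model_type: str, factory: str) -> str:
    msg = ""
    match model_type:
        case "embedding":
            msg += f"checked embedding model from {factory}"
        case "chat":
            msg += f"checked chat model from {factory}"
        case _:
            raise RuntimeError(f"Unknown model type: {model_type}")
    return msg


if __name__ == "__main__":
    print(check_model("chat", "ExampleFactory"))
```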
@@ -21,12 +21,11 @@ from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES

- from common.misc_utils import get_uuid
+ from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


@manager.route("/list", methods=["POST"])  # noqa: F821
@login_required
async def list_mcp() -> Response:
@@ -106,7 +105,7 @@ async def create() -> Response:
return get_data_error_result(message="Tenant not found.")

mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
- server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+ server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)

@@ -158,7 +157,7 @@ async def update() -> Response:
req["id"] = mcp_id

mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
- server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+ server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)

@@ -242,7 +241,7 @@ async def import_multiple() -> Response:
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
- server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+ server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
results.append({"server": base_name, "success": False, "message": err_message})
continue
@@ -322,9 +321,8 @@ async def list_tools() -> Response:
tool_call_sessions.append(tool_call_session)

try:
- tools = tool_call_session.get_tools(timeout)
+ tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
except Exception as e:
- tools = []
return get_data_error_result(message=f"MCP list tools error: {e}")

results[server_key] = []
@@ -340,7 +338,7 @@ async def list_tools() -> Response:
return server_error_response(e)
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
- close_multiple_mcp_toolcall_sessions(tool_call_sessions)
+ await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)


@manager.route("/test_tool", methods=["POST"])  # noqa: F821
@@ -367,10 +365,10 @@ async def test_tool() -> Response:

tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
- result = tool_call_session.tool_call(tool_name, arguments, timeout)
+ result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)

# PERF: blocking call to close sessions — consider moving to background thread or task queue
- close_multiple_mcp_toolcall_sessions(tool_call_sessions)
+ await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@@ -424,13 +422,12 @@ async def test_mcp() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)

try:
- tools = tool_call_session.get_tools(timeout)
+ tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
except Exception as e:
- tools = []
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
- close_multiple_mcp_toolcall_sessions([tool_call_session])
+ await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])

for tool in tools:
tool_dict = tool.model_dump()

@@ -14,20 +14,29 @@
# limitations under the License.
#

+ import asyncio
+ import base64
+ import hashlib
+ import hmac
+ import ipaddress
import json
import logging
import time
from typing import Any, cast

+ import jwt
+
from agent.canvas import Canvas
from api.db import CanvasCategory
from api.db.services.canvas_service import UserCanvasService
+ from api.db.services.file_service import FileService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
from api.utils.api_utils import get_result
from quart import request, Response
+ from rag.utils.redis_conn import REDIS_CONN


@manager.route('/agents', methods=['GET'])  # noqa: F821
@@ -42,7 +51,7 @@ def list_agents(tenant_id):
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
order_by = request.args.get("orderby", "update_time")
- if request.args.get("desc") == "False" or request.args.get("desc") == "false":
+ if str(request.args.get("desc", "false")).lower() == "false":
desc = False
else:
desc = True
@@ -132,48 +141,786 @@ def delete_agent(tenant_id: str, agent_id: str):
UserCanvasService.delete_by_id(agent_id)
|
UserCanvasService.delete_by_id(agent_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||||
|
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
|
||||||
|
async def webhook(agent_id: str):
|
||||||
|
is_test = request.path.startswith("/api/v1/webhook_test")
|
||||||
|
start_ts = time.time()
|
||||||
|
|
||||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
# 1. Fetch canvas by agent_id
|
||||||
@token_required
|
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||||
async def webhook(tenant_id: str, agent_id: str):
|
if not exists:
|
||||||
req = await get_request_json()
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
|
||||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
|
||||||
return get_json_result(
|
|
||||||
data=False, message='Only owner of canvas authorized for this operation.',
|
|
||||||
code=RetCode.OPERATING_ERROR)
|
|
||||||
|
|
||||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
|
||||||
if not e:
|
|
||||||
return get_data_error_result(message="canvas not found.")
|
|
||||||
|
|
||||||
if not isinstance(cvs.dsl, str):
|
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
# 2. Check canvas category
|
||||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 3. Load DSL from canvas
|
||||||
|
dsl = getattr(cvs, "dsl", None)
|
||||||
|
if not isinstance(dsl, dict):
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 4. Check webhook configuration in DSL
|
||||||
|
webhook_cfg = {}
|
||||||
|
components = dsl.get("components", {})
|
||||||
|
for k, _ in components.items():
|
||||||
|
cpn_obj = components[k]["obj"]
|
||||||
|
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
|
||||||
|
webhook_cfg = cpn_obj["params"]
|
||||||
|
|
||||||
|
if not webhook_cfg:
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 5. Validate request method against webhook_cfg.methods
|
||||||
|
allowed_methods = webhook_cfg.get("methods", [])
|
||||||
|
request_method = request.method.upper()
|
||||||
|
if allowed_methods and request_method not in allowed_methods:
|
||||||
|
return get_data_error_result(
|
||||||
|
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||||
|
),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 6. Validate webhook security
|
||||||
|
async def validate_webhook_security(security_cfg: dict):
|
||||||
|
"""Validate webhook security rules based on security configuration."""
|
||||||
|
|
||||||
|
if not security_cfg:
|
||||||
|
return # No security config → allowed by default
|
||||||
|
|
||||||
|
# 1. Validate max body size
|
||||||
|
await _validate_max_body_size(security_cfg)
|
||||||
|
|
||||||
|
# 2. Validate IP whitelist
|
||||||
|
_validate_ip_whitelist(security_cfg)
|
||||||
|
|
||||||
|
# # 3. Validate rate limiting
|
||||||
|
_validate_rate_limit(security_cfg)
|
||||||
|
|
||||||
|
# 4. Validate authentication
|
||||||
|
auth_type = security_cfg.get("auth_type", "none")
|
||||||
|
|
||||||
|
if auth_type == "none":
|
||||||
|
return
|
||||||
|
|
||||||
|
if auth_type == "token":
|
||||||
|
_validate_token_auth(security_cfg)
|
||||||
|
|
||||||
|
elif auth_type == "basic":
|
||||||
|
_validate_basic_auth(security_cfg)
|
||||||
|
|
||||||
|
elif auth_type == "jwt":
|
||||||
|
_validate_jwt_auth(security_cfg)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported auth_type: {auth_type}")
|
||||||
|
|
||||||
|
async def _validate_max_body_size(security_cfg):
|
||||||
|
"""Check request size does not exceed max_body_size."""
|
||||||
|
max_size = security_cfg.get("max_body_size")
|
||||||
|
if not max_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert "10MB" → bytes
|
||||||
|
units = {"kb": 1024, "mb": 1024**2}
|
||||||
|
size_str = max_size.lower()
|
||||||
|
|
||||||
|
for suffix, factor in units.items():
|
||||||
|
if size_str.endswith(suffix):
|
||||||
|
limit = int(size_str.replace(suffix, "")) * factor
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid max_body_size format")
|
||||||
|
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||||
|
if limit > MAX_LIMIT:
|
||||||
|
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
|
||||||
|
|
||||||
|
content_length = request.content_length or 0
|
||||||
|
if content_length > limit:
|
||||||
|
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||||
|
|
||||||
|
def _validate_ip_whitelist(security_cfg):
|
||||||
|
"""Allow only IPs listed in ip_whitelist."""
|
||||||
|
whitelist = security_cfg.get("ip_whitelist", [])
|
||||||
|
if not whitelist:
|
||||||
|
return
|
||||||
|
|
||||||
|
client_ip = request.remote_addr
|
||||||
|
|
||||||
|
|
||||||
|
for rule in whitelist:
|
||||||
|
if "/" in rule:
|
||||||
|
# CIDR notation
|
||||||
|
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Single IP
|
||||||
|
if client_ip == rule:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||||
|
|
||||||
|
def _validate_rate_limit(security_cfg):
|
||||||
|
"""Simple in-memory rate limiting."""
|
||||||
|
rl = security_cfg.get("rate_limit")
|
||||||
|
if not rl:
|
||||||
|
return
|
||||||
|
|
||||||
|
limit = int(rl.get("limit", 60))
|
||||||
|
if limit <= 0:
|
||||||
|
raise Exception("rate_limit.limit must be > 0")
|
||||||
|
per = rl.get("per", "minute")
|
||||||
|
|
||||||
|
window = {
|
||||||
|
"second": 1,
|
||||||
|
"minute": 60,
|
||||||
|
"hour": 3600,
|
||||||
|
"day": 86400,
|
||||||
|
}.get(per)
|
||||||
|
|
||||||
|
if not window:
|
||||||
|
raise Exception(f"Invalid rate_limit.per: {per}")
|
||||||
|
|
||||||
|
capacity = limit
|
||||||
|
rate = limit / window
|
||||||
|
cost = 1
|
||||||
|
|
||||||
|
key = f"rl:tb:{agent_id}"
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = REDIS_CONN.lua_token_bucket(
|
||||||
|
keys=[key],
|
||||||
|
args=[capacity, rate, now, cost],
|
||||||
|
client=REDIS_CONN.REDIS,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = int(res[0])
|
||||||
|
if allowed != 1:
|
||||||
|
raise Exception("Too many requests (rate limit exceeded)")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Rate limit error: {e}")
|
||||||
|
|
def _validate_token_auth(security_cfg):
    """Validate header-based token authentication."""
    token_cfg = security_cfg.get("token", {})
    header = token_cfg.get("token_header")
    token_value = token_cfg.get("token_value")

    provided = request.headers.get(header)
    if provided != token_value:
        raise Exception("Invalid token authentication")

def _validate_basic_auth(security_cfg):
    """Validate HTTP Basic Auth credentials."""
    auth_cfg = security_cfg.get("basic_auth", {})
    username = auth_cfg.get("username")
    password = auth_cfg.get("password")

    auth = request.authorization
    if not auth or auth.username != username or auth.password != password:
        raise Exception("Invalid Basic Auth credentials")

def _validate_jwt_auth(security_cfg):
    """Validate JWT token in Authorization header."""
    jwt_cfg = security_cfg.get("jwt", {})
    secret = jwt_cfg.get("secret")
    if not secret:
        raise Exception("JWT secret not configured")

    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        raise Exception("Missing Bearer token")

    token = auth_header[len("Bearer "):].strip()
    if not token:
        raise Exception("Empty Bearer token")

    alg = (jwt_cfg.get("algorithm") or "HS256").upper()

    decode_kwargs = {
        "key": secret,
        "algorithms": [alg],
    }
    options = {}
    if jwt_cfg.get("audience"):
        decode_kwargs["audience"] = jwt_cfg["audience"]
        options["verify_aud"] = True
    else:
        options["verify_aud"] = False

    if jwt_cfg.get("issuer"):
        decode_kwargs["issuer"] = jwt_cfg["issuer"]
        options["verify_iss"] = True
    else:
        options["verify_iss"] = False

    try:
        decoded = jwt.decode(
            token,
            options=options,
            **decode_kwargs,
        )
    except Exception as e:
        raise Exception(f"Invalid JWT: {str(e)}")

    raw_required_claims = jwt_cfg.get("required_claims", [])
    if isinstance(raw_required_claims, str):
        required_claims = [raw_required_claims]
    elif isinstance(raw_required_claims, (list, tuple, set)):
        required_claims = list(raw_required_claims)
    else:
        required_claims = []

    required_claims = [
        c for c in required_claims
        if isinstance(c, str) and c.strip()
    ]

    RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
    for claim in required_claims:
        if claim in RESERVED_CLAIMS:
            raise Exception(f"Reserved JWT claim cannot be required: {claim}")

    for claim in required_claims:
        if claim not in decoded:
            raise Exception(f"Missing JWT claim: {claim}")

    return decoded

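# A minimal sketch (not taken from the repo) of a webhook "security" config that the
# validators above would accept; the key names mirror the .get() calls in the code,
# but the concrete values are illustrative assumptions:
#
# security_cfg = {
#     "max_body_size": "5mb",                            # kb/mb suffix, capped at 10MB
#     "ip_whitelist": ["10.0.0.8", "192.168.1.0/24"],    # single IPs or CIDR ranges
#     "rate_limit": {"limit": 60, "per": "minute"},
#     "auth_type": "jwt",                                # one of: none / token / basic / jwt
#     "token": {"token_header": "X-Api-Token", "token_value": "..."},
#     "basic_auth": {"username": "...", "password": "..."},
#     "jwt": {
#         "secret": "...",
#         "algorithm": "HS256",
#         "audience": "my-service",                      # optional; enables verify_aud
#         "issuer": "my-issuer",                         # optional; enables verify_iss
#         "required_claims": ["tenant"],                 # custom claims only; reserved ones are rejected
#     },
# }
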
try:
    security_config = webhook_cfg.get("security", {})
    await validate_webhook_security(security_config)
except Exception as e:
    return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST

if not isinstance(cvs.dsl, str):
    dsl = json.dumps(cvs.dsl, ensure_ascii=False)
else:
    dsl = cvs.dsl

try:
    canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
except Exception as e:
    resp = get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e))
    resp.status_code = RetCode.BAD_REQUEST
    return resp

# 7. Parse request body
async def parse_webhook_request(content_type):
    """Parse request based on content-type and return structured data."""

    # 1. Query
    query_data = {k: v for k, v in request.args.items()}

    # 2. Headers
    header_data = {k: v for k, v in request.headers.items()}

    # 3. Body
    ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
    if ctype and ctype != content_type:
        raise ValueError(
            f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
        )

    body_data: dict = {}

    try:
        if ctype == "application/json":
            body_data = await request.get_json() or {}

        elif ctype == "multipart/form-data":
            form = await request.form
            files = await request.files

            body_data = {}
            for key, value in form.items():
                body_data[key] = value

            if len(files) > 10:
                raise Exception("Too many uploaded files")
            for key, file in files.items():
                desc = FileService.upload_info(
                    cvs.user_id,  # user
                    file,         # FileStorage
                    None          # url (None for webhook)
                )
                file_parsed = await canvas.get_files_async([desc])
                body_data[key] = file_parsed

        elif ctype == "application/x-www-form-urlencoded":
            form = await request.form
            body_data = dict(form)

        else:
            # text/plain / octet-stream / empty / unknown
            raw = await request.get_data()
            if raw:
                try:
                    body_data = json.loads(raw.decode("utf-8"))
                except Exception:
                    body_data = {}
            else:
                body_data = {}

    except Exception:
        body_data = {}

    return {
        "query": query_data,
        "headers": header_data,
        "body": body_data,
        "content_type": ctype,
    }

def extract_by_schema(data, schema, name="section"):
    """
    Extract only fields defined in schema.
    Required fields must exist.
    Optional fields default to type-based default values.
    Type validation included.
    """
    props = schema.get("properties", {})
    required = schema.get("required", [])

    extracted = {}

    for field, field_schema in props.items():
        field_type = field_schema.get("type")

        # 1. Required field missing
        if field in required and field not in data:
            raise Exception(f"{name} missing required field: {field}")

        # 2. Optional → default value
        if field not in data:
            extracted[field] = default_for_type(field_type)
            continue

        raw_value = data[field]

        # 3. Auto convert value
        try:
            value = auto_cast_value(raw_value, field_type)
        except Exception as e:
            raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")

        # 4. Type validation
        if not validate_type(value, field_type):
            raise Exception(
                f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
            )

        extracted[field] = value

    return extracted

def default_for_type(t):
    """Return default value for the given schema type."""
    if t == "file":
        return []
    if t == "object":
        return {}
    if t == "boolean":
        return False
    if t == "number":
        return 0
    if t == "string":
        return ""
    if t and t.startswith("array"):
        return []
    if t == "null":
        return None
    return None

def auto_cast_value(value, expected_type):
    """Convert string values into schema type when possible."""

    # Non-string values already good
    if not isinstance(value, str):
        return value

    v = value.strip()

    # Boolean
    if expected_type == "boolean":
        if v.lower() in ["true", "1"]:
            return True
        if v.lower() in ["false", "0"]:
            return False
        raise Exception(f"Cannot convert '{value}' to boolean")

    # Number
    if expected_type == "number":
        # integer
        if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
            return int(v)
        # float
        try:
            return float(v)
        except Exception:
            raise Exception(f"Cannot convert '{value}' to number")

    # Object
    if expected_type == "object":
        try:
            parsed = json.loads(v)
            if isinstance(parsed, dict):
                return parsed
            else:
                raise Exception("JSON is not an object")
        except Exception:
            raise Exception(f"Cannot convert '{value}' to object")

    # Array<T>
    if expected_type.startswith("array"):
        try:
            parsed = json.loads(v)
            if isinstance(parsed, list):
                return parsed
            else:
                raise Exception("JSON is not an array")
        except Exception:
            raise Exception(f"Cannot convert '{value}' to array")

    # String (accept original)
    if expected_type == "string":
        return value

    # File
    if expected_type == "file":
        return value

    # Default: do nothing
    return value

def validate_type(value, t):
    """Validate value type against schema type t."""
    if t == "file":
        return isinstance(value, list)

    if t == "string":
        return isinstance(value, str)

    if t == "number":
        return isinstance(value, (int, float))

    if t == "boolean":
        return isinstance(value, bool)

    if t == "object":
        return isinstance(value, dict)

    # array<string> / array<number> / array<object>
    if t.startswith("array"):
        if not isinstance(value, list):
            return False

        if "<" in t and ">" in t:
            inner = t[t.find("<") + 1 : t.find(">")]

            # Check each element type
            for item in value:
                if not validate_type(item, inner):
                    return False

            return True

    return True

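# A minimal usage sketch of the schema helpers above; the schema layout and field names
# here are invented for illustration and are not taken from the repository:
#
# body_schema = {
#     "required": ["order_id"],
#     "properties": {
#         "order_id": {"type": "string"},
#         "amount": {"type": "number"},
#         "tags": {"type": "array<string>"},
#     },
# }
# extract_by_schema({"order_id": "A-1", "amount": "12.5", "tags": '["new"]'}, body_schema, name="body")
# # -> {"order_id": "A-1", "amount": 12.5, "tags": ["new"]}
# # "amount" and "tags" are auto-cast from strings; omitted optional fields fall back to
# # the defaults returned by default_for_type().
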
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})

# Extract strictly by schema
try:
    query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
    header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
    body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
except Exception as e:
    return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST

clean_request = {
    "query": query_clean,
    "headers": header_clean,
    "body": body_clean,
    "input": parsed
}

execution_mode = webhook_cfg.get("execution_mode", "Immediately")
response_cfg = webhook_cfg.get("response", {})

def append_webhook_trace(agent_id: str, start_ts: float, event: dict, ttl=600):
    key = f"webhook-trace-{agent_id}-logs"

    raw = REDIS_CONN.get(key)
    obj = json.loads(raw) if raw else {"webhooks": {}}

    ws = obj["webhooks"].setdefault(
        str(start_ts),
        {"start_ts": start_ts, "events": []}
    )

    ws["events"].append({
        "ts": time.time(),
        **event
    })

    REDIS_CONN.set_obj(key, obj, ttl)

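# For orientation, the Redis value written above ends up shaped roughly like the sketch
# below; timestamps and event payloads are illustrative, not captured from a real run:
#
# {
#     "webhooks": {
#         "1718000000.12": {
#             "start_ts": 1718000000.12,
#             "events": [
#                 {"ts": 1718000000.30, "event": "message", "data": {...}},
#                 {"ts": 1718000001.05, "event": "finished", "elapsed_time": 0.93, "success": true}
#             ]
#         }
#     }
# }
#
# The key "webhook-trace-{agent_id}-logs" expires after ttl seconds (600 by default), and
# is what the /webhook_trace/<agent_id> polling endpoint below reads back.
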
if execution_mode == "Immediately":
    status = response_cfg.get("status", 200)
    try:
        status = int(status)
    except (TypeError, ValueError):
        return get_data_error_result(code=RetCode.BAD_REQUEST, message=f"Invalid response status code: {status}"), RetCode.BAD_REQUEST

    if not (200 <= status <= 399):
        return get_data_error_result(code=RetCode.BAD_REQUEST, message=f"Invalid response status code: {status}, must be between 200 and 399"), RetCode.BAD_REQUEST

    body_tpl = response_cfg.get("body_template", "")

    def parse_body(body: str):
        if not body:
            return None, "application/json"

        try:
            parsed = json.loads(body)
            return parsed, "application/json"
        except (json.JSONDecodeError, TypeError):
            return body, "text/plain"

    body, content_type = parse_body(body_tpl)
    resp = Response(
        json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
        status=status,
        content_type=content_type,
    )

    async def background_run():
        try:
            async for ans in canvas.run(
                query="",
                user_id=cvs.user_id,
                webhook_payload=clean_request
            ):
                if is_test:
                    append_webhook_trace(agent_id, start_ts, ans)

            if is_test:
                append_webhook_trace(
                    agent_id,
                    start_ts,
                    {
                        "event": "finished",
                        "elapsed_time": time.time() - start_ts,
                        "success": True,
                    }
                )

            cvs.dsl = json.loads(str(canvas))
            UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())

        except Exception as e:
            logging.exception("Webhook background run failed")
            if is_test:
                try:
                    append_webhook_trace(
                        agent_id,
                        start_ts,
                        {
                            "event": "error",
                            "message": str(e),
                            "error_type": type(e).__name__,
                        }
                    )
                    append_webhook_trace(
                        agent_id,
                        start_ts,
                        {
                            "event": "finished",
                            "elapsed_time": time.time() - start_ts,
                            "success": False,
                        }
                    )
                except Exception:
                    logging.exception("Failed to append webhook trace")

    asyncio.create_task(background_run())
    return resp
else:
    async def sse():
        nonlocal canvas
        contents: list[str] = []
        status = 200
        try:
            async for ans in canvas.run(
                query="",
                user_id=cvs.user_id,
                webhook_payload=clean_request,
            ):
                if ans["event"] == "message":
                    content = ans["data"]["content"]
                    if ans["data"].get("start_to_think", False):
                        content = "<think>"
                    elif ans["data"].get("end_to_think", False):
                        content = "</think>"
                    if content:
                        contents.append(content)
                if ans["event"] == "message_end":
                    status = int(ans["data"].get("status", status))
                if is_test:
                    append_webhook_trace(
                        agent_id,
                        start_ts,
                        ans
                    )
            if is_test:
                append_webhook_trace(
                    agent_id,
                    start_ts,
                    {
                        "event": "finished",
                        "elapsed_time": time.time() - start_ts,
                        "success": True,
                    }
                )
            final_content = "".join(contents)
            return {
                "message": final_content,
                "success": True,
                "code": status,
            }

        except Exception as e:
            if is_test:
                append_webhook_trace(
                    agent_id,
                    start_ts,
                    {
                        "event": "error",
                        "message": str(e),
                        "error_type": type(e).__name__,
                    }
                )
                append_webhook_trace(
                    agent_id,
                    start_ts,
                    {
                        "event": "finished",
                        "elapsed_time": time.time() - start_ts,
                        "success": False,
                    }
                )
            return {"code": 400, "message": str(e), "success": False}

    result = await sse()
    return Response(
        json.dumps(result),
        status=result["code"],
        mimetype="application/json",
    )

@manager.route("/webhook_trace/<agent_id>", methods=["GET"])  # noqa: F821
async def webhook_trace(agent_id: str):
    def encode_webhook_id(start_ts: str) -> str:
        WEBHOOK_ID_SECRET = "webhook_id_secret"
        sig = hmac.new(
            WEBHOOK_ID_SECRET.encode("utf-8"),
            start_ts.encode("utf-8"),
            hashlib.sha256,
        ).digest()
        return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")

    def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
        for ts in webhooks.keys():
            if encode_webhook_id(ts) == enc_id:
                return ts
        return None

    since_ts = request.args.get("since_ts", type=float)
    webhook_id = request.args.get("webhook_id")

    key = f"webhook-trace-{agent_id}-logs"
    raw = REDIS_CONN.get(key)

    if since_ts is None:
        now = time.time()
        return get_json_result(
            data={
                "webhook_id": None,
                "events": [],
                "next_since_ts": now,
                "finished": False,
            }
        )

    if not raw:
        return get_json_result(
            data={
                "webhook_id": None,
                "events": [],
                "next_since_ts": since_ts,
                "finished": False,
            }
        )

    obj = json.loads(raw)
    webhooks = obj.get("webhooks", {})

    if webhook_id is None:
        candidates = [
            float(k) for k in webhooks.keys() if float(k) > since_ts
        ]

        if not candidates:
            return get_json_result(
                data={
                    "webhook_id": None,
                    "events": [],
                    "next_since_ts": since_ts,
                    "finished": False,
                }
            )

        start_ts = min(candidates)
        real_id = str(start_ts)
        webhook_id = encode_webhook_id(real_id)

        return get_json_result(
            data={
                "webhook_id": webhook_id,
                "events": [],
                "next_since_ts": start_ts,
                "finished": False,
            }
        )

    real_id = decode_webhook_id(webhook_id, webhooks)

    if not real_id:
        return get_json_result(
            data={
                "webhook_id": webhook_id,
                "events": [],
                "next_since_ts": since_ts,
                "finished": True,
            }
        )

    ws = webhooks.get(str(real_id))
    events = ws.get("events", [])
    new_events = [e for e in events if e.get("ts", 0) > since_ts]

    next_ts = since_ts
    for e in new_events:
        next_ts = max(next_ts, e["ts"])

    finished = any(e.get("event") == "finished" for e in new_events)

    return get_json_result(
        data={
            "webhook_id": webhook_id,
            "events": new_events,
            "next_since_ts": next_ts,
            "finished": finished,
        }
    )

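# A possible polling flow against the trace endpoint above (the URL prefix in front of
# /webhook_trace/<agent_id> depends on how the blueprint is mounted, so the exact paths
# below are assumptions for illustration):
#
# 1. GET /webhook_trace/<agent_id>                               -> returns next_since_ts only
# 2. GET /webhook_trace/<agent_id>?since_ts=<ts>                 -> returns a webhook_id once a run starts
# 3. GET /webhook_trace/<agent_id>?since_ts=<ts>&webhook_id=<id> -> returns new events; repeat,
#        advancing since_ts to next_since_ts, until "finished" is true or the Redis key expires.
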
@@ -51,7 +51,9 @@ async def create(tenant_id):
         req["llm_id"] = llm.pop("model_name")
     if req.get("llm_id") is not None:
         llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
-        if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
+        model_type = llm.get("model_type")
+        model_type = model_type if model_type in ["chat", "image2text"] else "chat"
+        if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
             return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
     req["llm_setting"] = req.pop("llm")
     e, tenant = TenantService.get_by_id(tenant_id)

@@ -174,7 +176,7 @@ async def update(tenant_id, chat_id):
         req["llm_id"] = llm.pop("model_name")
     if req.get("llm_id") is not None:
         llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
-        model_type = llm.pop("model_type")
+        model_type = llm.get("model_type")
         model_type = model_type if model_type in ["chat", "image2text"] else "chat"
         if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
             return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")

@@ -252,7 +254,6 @@ async def delete_chats(tenant_id):
             continue
         temp_dict = {"status": StatusEnum.INVALID.value}
         success_count += DialogService.update_by_id(id, temp_dict)
-        print(success_count, "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$", flush=True)

     if errors:
         if success_count > 0:

@@ -288,7 +289,7 @@ def list_chat(tenant_id):
     chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name)
     if not chats:
         return get_result(data=[])
-    list_assts = []
+    list_assistants = []
     key_mapping = {
         "parameters": "variables",
         "prologue": "opener",

@@ -322,5 +323,5 @@ def list_chat(tenant_id):
         del res["kb_ids"]
         res["datasets"] = kb_list
         res["avatar"] = res.pop("icon")
-        list_assts.append(res)
-    return get_result(data=list_assts)
+        list_assistants.append(res)
+    return get_result(data=list_assistants)

@@ -233,6 +233,15 @@ async def delete(tenant_id):
             File2DocumentService.delete_by_document_id(doc.id)
         FileService.filter_delete(
             [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
+
+        # Drop index for this dataset
+        try:
+            from rag.nlp import search
+            idxnm = search.index_name(kb.tenant_id)
+            settings.docStoreConn.delete_idx(idxnm, kb_id)
+        except Exception as e:
+            logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")
+
         if not KnowledgebaseService.delete_by_id(kb_id):
             errors.append(f"Delete dataset error for {kb_id}")
             continue

@@ -481,7 +490,7 @@ def list_datasets(tenant_id):

 @manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET'])  # noqa: F821
 @token_required
-def knowledge_graph(tenant_id, dataset_id):
+async def knowledge_graph(tenant_id, dataset_id):
     if not KnowledgebaseService.accessible(dataset_id, tenant_id):
         return get_result(
             data=False,

@@ -495,9 +504,9 @@ def knowledge_graph(tenant_id, dataset_id):
     }

     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
+    if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
         return get_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+    sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
     if not len(sres.ids):
         return get_result(data=obj)

@@ -135,7 +135,7 @@ async def retrieval(tenant_id):
         doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
         if not doc_ids and metadata_condition:
             doc_ids = ["-999"]
-    ranks = settings.retriever.retrieval(
+    ranks = await settings.retriever.retrieval(
         question,
         embd_mdl,
         kb.tenant_id,

@@ -150,7 +150,7 @@ async def retrieval(tenant_id):
     )

     if use_kg:
-        ck = settings.kg_retriever.retrieval(question,
+        ck = await settings.kg_retriever.retrieval(question,
                                              [tenant_id],
                                              [kb_id],
                                              embd_mdl,

@@ -606,12 +606,12 @@ def list_docs(dataset_id, tenant_id):

 @manager.route("/datasets/<dataset_id>/metadata/summary", methods=["GET"])  # noqa: F821
 @token_required
-def metadata_summary(dataset_id, tenant_id):
+async def metadata_summary(dataset_id, tenant_id):
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
         return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
+    req = await get_request_json()
     try:
-        summary = DocumentService.get_metadata_summary(dataset_id)
+        summary = DocumentService.get_metadata_summary(dataset_id, req.get("doc_ids"))
         return get_result(data={"summary": summary})
     except Exception as e:
         return server_error_response(e)

@@ -648,9 +648,9 @@ async def metadata_batch_update(dataset_id, tenant_id):
         if not isinstance(d, dict) or not d.get("key"):
             return get_error_data_result(message="Each delete requires key.")

-    kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
-    target_doc_ids = set(kb_doc_ids)
     if document_ids:
+        kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
+        target_doc_ids = set(kb_doc_ids)
         invalid_ids = set(document_ids) - set(kb_doc_ids)
         if invalid_ids:
             return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")

@@ -935,7 +935,7 @@ async def stop_parsing(tenant_id, dataset_id):

 @manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"])  # noqa: F821
 @token_required
-def list_chunks(tenant_id, dataset_id, document_id):
+async def list_chunks(tenant_id, dataset_id, document_id):
     """
     List chunks of a document.
     ---

@@ -1080,8 +1080,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
             res["chunks"].append(final_chunk)
             _ = Chunk(**final_chunk)

-    elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-        sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+    elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
+        sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
         res["total"] = sres.total
         for id in sres.ids:
             d = {

@@ -1286,6 +1286,9 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
     if "chunk_ids" in req:
         unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
         condition["id"] = unique_chunk_ids
+    else:
+        unique_chunk_ids = []
+        duplicate_messages = []
     chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
     if chunk_number != 0:
         DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)

@@ -1517,10 +1520,11 @@ async def retrieval_test(tenant_id):
     langs = req.get("cross_languages", [])
     if not isinstance(doc_ids, list):
         return get_error_data_result("`documents` should be a list")
-    doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
-    for doc_id in doc_ids:
-        if doc_id not in doc_ids_list:
-            return get_error_data_result(f"The datasets don't own the document {doc_id}")
+    if doc_ids:
+        doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
+        for doc_id in doc_ids:
+            if doc_id not in doc_ids_list:
+                return get_error_data_result(f"The datasets don't own the document {doc_id}")
     if not doc_ids:
         metadata_condition = req.get("metadata_condition", {}) or {}
         metas = DocumentService.get_meta_by_kbs(kb_ids)

@@ -1555,7 +1559,7 @@ async def retrieval_test(tenant_id):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += await keyword_extraction(chat_mdl, question)

-        ranks = settings.retriever.retrieval(
+        ranks = await settings.retriever.retrieval(
             question,
             embd_mdl,
             tenant_ids,

@@ -1572,11 +1576,11 @@ async def retrieval_test(tenant_id):
         )
         if toc_enhance:
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
+            cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
             if cks:
                 ranks["chunks"] = cks
         if use_kg:
-            ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
+            ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)

@@ -14,7 +14,6 @@
 # limitations under the License.
 #

-import asyncio
 import pathlib
 import re
 from quart import request, make_response

@@ -24,7 +23,7 @@ from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.db import FileType
 from api.db.services import duplicate_name
 from api.db.services.file_service import FileService

@@ -33,7 +32,6 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 from common.constants import RetCode

-
 @manager.route('/file/upload', methods=['POST'])  # noqa: F821
 @token_required
 async def upload(tenant_id):

@@ -205,7 +203,8 @@ async def create(tenant_id):
         if not FileService.is_parent_folder_exist(pf_id):
             return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
         if FileService.query(name=req["name"], parent_id=pf_id):
-            return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
+            return get_json_result(data=False, message="Duplicated folder name in the same folder.",
+                                   code=RetCode.CONFLICT)

         if input_file_type == FileType.FOLDER.value:
             file_type = FileType.FOLDER.value

@@ -565,11 +564,13 @@ async def rename(tenant_id):

         if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
                 file.name.lower()).suffix:
-            return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
+            return get_json_result(data=False, message="The extension of file can't be changed",
+                                   code=RetCode.BAD_REQUEST)

         for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
             if existing_file.name == req["name"]:
-                return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
+                return get_json_result(data=False, message="Duplicated file name in the same folder.",
+                                       code=RetCode.CONFLICT)

         if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
             return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)

@@ -631,12 +632,13 @@ async def get(tenant_id, file_id):
     except Exception as e:
         return server_error_response(e)


 @manager.route("/file/download/<attachment_id>", methods=["GET"])  # noqa: F821
 @token_required
-async def download_attachment(tenant_id,attachment_id):
+async def download_attachment(tenant_id, attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))

@@ -645,6 +647,7 @@ async def download_attachment(tenant_id,attachment_id):
     except Exception as e:
         return server_error_response(e)

+
 @manager.route('/file/mv', methods=['POST'])  # noqa: F821
 @token_required
 async def move(tenant_id):

@@ -14,39 +14,79 @@
 # limitations under the License.
 #
 import logging
+import os
+import time
+
 from quart import request
 from api.apps import login_required, current_user
 from api.db import TenantPermission
 from api.db.services.memory_service import MemoryService
 from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
-    not_allowed_parameters
+from api.db.services.canvas_service import UserCanvasService
+from api.db.services.task_service import TaskService
+from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
+from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
 from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
 from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
+from memory.services.messages import MessageService
+from memory.utils.prompt_util import PromptAssembler
 from common.constants import MemoryType, RetCode, ForgettingPolicy


-@manager.route("", methods=["POST"])  # noqa: F821
+@manager.route("/memories", methods=["POST"])  # noqa: F821
 @login_required
 @validate_request("name", "memory_type", "embd_id", "llm_id")
 async def create_memory():
+    timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+    t_start = time.perf_counter() if timing_enabled else None
     req = await get_request_json()
+    t_parsed = time.perf_counter() if timing_enabled else None
     # check name length
     name = req["name"]
     memory_name = name.strip()
     if len(memory_name) == 0:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result("Memory name cannot be empty or whitespace.")
     if len(memory_name) > MEMORY_NAME_LIMIT:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
     # check memory_type valid
+    if not isinstance(req["memory_type"], list):
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        return get_error_argument_result("Memory type must be a list.")
     memory_type = set(req["memory_type"])
     invalid_type = memory_type - {e.name.lower() for e in MemoryType}
     if invalid_type:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
     memory_type = list(memory_type)
+
     try:
+        t_before_db = time.perf_counter() if timing_enabled else None
         res, memory = MemoryService.create_memory(
             tenant_id=current_user.id,
             name=memory_name,

@@ -54,10 +94,18 @@ async def create_memory():
             embd_id=req["embd_id"],
             llm_id=req["llm_id"]
         )
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (t_before_db - t_parsed) * 1000,
+                (time.perf_counter() - t_before_db) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+
         if res:
             return get_json_result(message=True, data=format_ret_data_from_memory(memory))

         else:
             return get_json_result(message=memory, code=RetCode.SERVER_ERROR)

@@ -65,9 +113,8 @@ async def create_memory():
         return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)


-@manager.route("/<memory_id>", methods=["PUT"])  # noqa: F821
+@manager.route("/memories/<memory_id>", methods=["PUT"])  # noqa: F821
 @login_required
-@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
 async def update_memory(memory_id):
     req = await get_request_json()
     update_dict = {}

@@ -87,6 +134,14 @@ async def update_memory(memory_id):
         update_dict["permissions"] = req["permissions"]
     if req.get("llm_id"):
         update_dict["llm_id"] = req["llm_id"]
+    if req.get("embd_id"):
+        update_dict["embd_id"] = req["embd_id"]
+    if req.get("memory_type"):
+        memory_type = set(req["memory_type"])
+        invalid_type = memory_type - {e.name.lower() for e in MemoryType}
+        if invalid_type:
+            return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
+        update_dict["memory_type"] = list(memory_type)
     # check memory_size valid
     if req.get("memory_size"):
         if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT:

@@ -122,9 +177,18 @@ async def update_memory(memory_id):

     if not to_update:
         return get_json_result(message=True, data=memory_dict)
+    # check memory empty when update embd_id, memory_type
+    memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
+    not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
+    if not_allowed_update:
+        return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.")
+    if "memory_type" in to_update:
+        if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
+            # update old default prompt, assemble a new one
+            to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})
+
     try:
-        MemoryService.update_memory(memory_id, to_update)
+        MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
         updated_memory = MemoryService.get_by_memory_id(memory_id)
         return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))

@@ -133,7 +197,7 @@ async def update_memory(memory_id):
         return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)


-@manager.route("/<memory_id>", methods=["DELETE"])  # noqa: F821
+@manager.route("/memories/<memory_id>", methods=["DELETE"])  # noqa: F821
 @login_required
 async def delete_memory(memory_id):
     memory = MemoryService.get_by_memory_id(memory_id)

@@ -141,13 +205,15 @@ async def delete_memory(memory_id):
         return get_json_result(message=True, code=RetCode.NOT_FOUND)
     try:
         MemoryService.delete_memory(memory_id)
+        if MessageService.has_index(memory.tenant_id, memory_id):
+            MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
         return get_json_result(message=True)
     except Exception as e:
         logging.error(e)
         return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)


-@manager.route("", methods=["GET"])  # noqa: F821
+@manager.route("/memories", methods=["GET"])  # noqa: F821
 @login_required
 async def list_memory():
     args = request.args

@ -159,13 +225,18 @@ async def list_memory():
|
|||||||
page = int(args.get("page", 1))
|
page = int(args.get("page", 1))
|
||||||
page_size = int(args.get("page_size", 50))
|
page_size = int(args.get("page_size", 50))
|
||||||
# make filter dict
|
# make filter dict
|
||||||
filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
|
filter_dict: dict = {"storage_type": storage_type}
|
||||||
if not tenant_ids:
|
if not tenant_ids:
|
||||||
# restrict to current user's tenants
|
# restrict to current user's tenants
|
||||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||||
else:
|
else:
|
||||||
|
if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
|
||||||
|
tenant_ids = tenant_ids[0].split(',')
|
||||||
filter_dict["tenant_id"] = tenant_ids
|
filter_dict["tenant_id"] = tenant_ids
|
||||||
|
if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
|
||||||
|
memory_types = memory_types[0].split(',')
|
||||||
|
filter_dict["memory_type"] = memory_types
|
||||||
|
|
||||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||||
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
||||||
@@ -176,10 +247,45 @@ async def list_memory():
         return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)


-@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
+@manager.route("/memories/<memory_id>/config", methods=["GET"]) # noqa: F821
 @login_required
 async def get_memory_config(memory_id):
     memory = MemoryService.get_with_owner_name_by_id(memory_id)
     if not memory:
         return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
     return get_json_result(message=True, data=format_ret_data_from_memory(memory))


+@manager.route("/memories/<memory_id>", methods=["GET"]) # noqa: F821
+@login_required
+async def get_memory_detail(memory_id):
+    args = request.args
+    agent_ids = args.getlist("agent_id")
+    if len(agent_ids) == 1 and ',' in agent_ids[0]:
+        agent_ids = agent_ids[0].split(',')
+    keywords = args.get("keywords", "")
+    keywords = keywords.strip()
+    page = int(args.get("page", 1))
+    page_size = int(args.get("page_size", 50))
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
+    messages = MessageService.list_message(
+        memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
+    agent_name_mapping = {}
+    extract_task_mapping = {}
+    if messages["message_list"]:
+        agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
+        agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
+        task_list = TaskService.get_tasks_progress_by_doc_ids([memory_id])
+        if task_list:
+            task_list.sort(key=lambda t: t["create_time"])  # asc, use newer when exist more than one task
+            for task in task_list:
+                # the 'digest' field carries the source_id when a task is created, so use 'digest' as key
+                extract_task_mapping.update({int(task["digest"]): task})
+        for message in messages["message_list"]:
+            message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
+            message["task"] = extract_task_mapping.get(message["message_id"], {})
+            for extract_msg in message["extract"]:
+                extract_msg["agent_name"] = agent_name_mapping.get(extract_msg["agent_id"], "Unknown")
+    return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
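For orientation, a minimal sketch of how the reworked memory routes above might be called; the mount prefix, port and IDs are illustrative assumptions, not part of the change:

# hypothetical client-side sketch (assumes an authenticated requests.Session and an assumed mount prefix)
import requests
s = requests.Session()  # assumed to already carry a valid login cookie, since the route uses @login_required
base = "http://localhost:9380/v1/memory"  # assumed mount point of this blueprint
r = s.get(f"{base}/memories/mem_123", params={"agent_id": "agent_1", "keywords": "order", "page": 1, "page_size": 50})
print(r.json()["data"]["storage_type"])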
api/apps/sdk/messages.py  (new file, 158 lines)
@@ -0,0 +1,158 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from quart import request
+from api.apps import login_required
+from api.db.services.memory_service import MemoryService
+from common.time_utils import current_timestamp, timestamp_to_date
+
+from memory.services.messages import MessageService
+from api.db.joint_services import memory_message_service
+from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
+from common.constants import RetCode
+
+
+@manager.route("/messages", methods=["POST"]) # noqa: F821
+@login_required
+@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
+async def add_message():
+    req = await get_request_json()
+    memory_ids = req["memory_id"]
+
+    message_dict = {
+        "user_id": req.get("user_id"),
+        "agent_id": req["agent_id"],
+        "session_id": req["session_id"],
+        "user_input": req["user_input"],
+        "agent_response": req["agent_response"],
+    }
+
+    res, msg = await memory_message_service.queue_save_to_memory_task(memory_ids, message_dict)
+
+    if res:
+        return get_json_result(message=msg)
+
+    return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add. Detail:" + msg)
+
+
+@manager.route("/messages/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
+@login_required
+async def forget_message(memory_id: str, message_id: int):
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
+
+    forget_time = timestamp_to_date(current_timestamp())
+    update_succeed = MessageService.update_message(
+        {"memory_id": memory_id, "message_id": int(message_id)},
+        {"forget_at": forget_time},
+        memory.tenant_id, memory_id)
+    if update_succeed:
+        return get_json_result(message=update_succeed)
+    else:
+        return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
+
+
+@manager.route("/messages/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
+@login_required
+@validate_request("status")
+async def update_message(memory_id: str, message_id: int):
+    req = await get_request_json()
+    status = req["status"]
+    if not isinstance(status, bool):
+        return get_error_argument_result("Status must be a boolean.")
+
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
+
+    update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id)
+    if update_succeed:
+        return get_json_result(message=update_succeed)
+    else:
+        return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
+
+
+@manager.route("/messages/search", methods=["GET"]) # noqa: F821
+@login_required
+async def search_message():
+    args = request.args
+    empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
+    if empty_fields:
+        return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
+
+    memory_ids = args.getlist("memory_id")
+    if len(memory_ids) == 1 and ',' in memory_ids[0]:
+        memory_ids = memory_ids[0].split(',')
+    query = args.get("query")
+    similarity_threshold = float(args.get("similarity_threshold", 0.2))
+    keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
+    top_n = int(args.get("top_n", 5))
+    agent_id = args.get("agent_id", "")
+    session_id = args.get("session_id", "")
+
+    filter_dict = {
+        "memory_id": memory_ids,
+        "agent_id": agent_id,
+        "session_id": session_id
+    }
+    params = {
+        "query": query,
+        "similarity_threshold": similarity_threshold,
+        "keywords_similarity_weight": keywords_similarity_weight,
+        "top_n": top_n
+    }
+    res = memory_message_service.query_message(filter_dict, params)
+    return get_json_result(message=True, data=res)
+
+
+@manager.route("/messages", methods=["GET"]) # noqa: F821
+@login_required
+async def get_messages():
+    args = request.args
+    memory_ids = args.getlist("memory_id")
+    if len(memory_ids) == 1 and ',' in memory_ids[0]:
+        memory_ids = memory_ids[0].split(',')
+    agent_id = args.get("agent_id", "")
+    session_id = args.get("session_id", "")
+    limit = int(args.get("limit", 10))
+    if not memory_ids:
+        return get_error_argument_result("memory_ids is required.")
+    memory_list = MemoryService.get_by_ids(memory_ids)
+    uids = [memory.tenant_id for memory in memory_list]
+    res = MessageService.get_recent_messages(
+        uids,
+        memory_ids,
+        agent_id,
+        session_id,
+        limit
+    )
+    return get_json_result(message=True, data=res)
+
+
+@manager.route("/messages/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
+@login_required
+async def get_message_content(memory_id: str, message_id: int):
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
+
+    res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
+    if res:
+        return get_json_result(message=True, data=res)
+    else:
+        return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.")
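A rough sketch of how a client might exercise the new message routes; the mount prefix, port, session cookie and IDs are illustrative assumptions, not part of this file:

# hypothetical usage sketch (assumes an already-authenticated requests.Session)
import requests
s = requests.Session()  # assumed to carry a valid login cookie, since the routes use @login_required
base = "http://localhost:9380/v1/memory"  # assumed mount point of this blueprint
# store one user/agent exchange into two memories at once
s.post(f"{base}/messages", json={"memory_id": ["mem_a", "mem_b"], "agent_id": "agent_1",
                                 "session_id": "sess_1", "user_input": "hi", "agent_response": "hello"})
# hybrid search across those memories
s.get(f"{base}/messages/search", params={"memory_id": "mem_a,mem_b", "query": "hello", "top_n": 5})
# soft-delete ("forget") a single message
s.delete(f"{base}/messages/mem_a:123")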
@@ -14,10 +14,15 @@
 # limitations under the License.
 #
 import json
+import copy
 import re
 import time

 import tiktoken
+import os
+import tempfile
+import logging

 from quart import Response, jsonify, request

 from agent.canvas import Canvas
@@ -32,9 +37,9 @@ from api.db.services.dialog_service import DialogService, async_ask, async_chat,
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from common.metadata_utils import apply_meta_data_filter
+from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
 from api.db.services.search_service import SearchService
-from api.db.services.user_service import UserTenantService
+from api.db.services.user_service import TenantService,UserTenantService
 from common.misc_utils import get_uuid
 from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
     get_result, get_request_json, server_error_response, token_required, validate_request
@@ -59,7 +64,7 @@ async def create(tenant_id, chat_id):
         "name": req.get("name", "New session"),
         "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
         "user_id": req.get("user_id", ""),
-        "reference": [{}],
+        "reference": [],
     }
     if not conv.get("name"):
         return get_error_data_result(message="`name` can not be empty.")
@@ -87,7 +92,7 @@ async def create_agent_session(tenant_id, agent_id):
     cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

     session_id = get_uuid()
-    canvas = Canvas(cvs.dsl, tenant_id, agent_id)
+    canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id)
     canvas.reset()

     cvs.dsl = json.loads(str(canvas))
@@ -128,11 +133,33 @@ async def chat_completion(tenant_id, chat_id):
         req = {"question": ""}
     if not req.get("session_id"):
         req["question"] = ""
-    if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
+    dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
+    if not dia:
         return get_error_data_result(f"You don't own the chat {chat_id}")
+    dia = dia[0]
     if req.get("session_id"):
         if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
             return get_error_data_result(f"You don't own the session {req['session_id']}")
+
+    metadata_condition = req.get("metadata_condition") or {}
+    if metadata_condition and not isinstance(metadata_condition, dict):
+        return get_error_data_result(message="metadata_condition must be an object.")
+
+    if metadata_condition and req.get("question"):
+        metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
+        filtered_doc_ids = meta_filter(
+            metas,
+            convert_conditions(metadata_condition),
+            metadata_condition.get("logic", "and"),
+        )
+        if metadata_condition.get("conditions") and not filtered_doc_ids:
+            filtered_doc_ids = ["-999"]
+
+        if filtered_doc_ids:
+            req["doc_ids"] = ",".join(filtered_doc_ids)
+        else:
+            req.pop("doc_ids", None)
+
     if req.get("stream", True):
         resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
         resp.headers.add_header("Cache-control", "no-cache")
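The filtering added above only fires when a request carries both a question and a metadata_condition; a sketch of such a request body, with made-up field values, looks like this:

# illustrative payload for the chat-completion route with metadata filtering (values are placeholders)
payload = {
    "question": "What did bob write about pricing?",
    "stream": False,
    "metadata_condition": {
        "logic": "and",
        "conditions": [
            {"name": "author", "comparison_operator": "is", "value": "bob"}
        ]
    }
}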
@@ -195,7 +222,19 @@ async def chat_completion_openai_like(tenant_id, chat_id):
             {"role": "user", "content": "Can you tell me how to install neovim"},
         ],
         stream=stream,
-        extra_body={"reference": reference}
+        extra_body={
+            "reference": reference,
+            "metadata_condition": {
+                "logic": "and",
+                "conditions": [
+                    {
+                        "name": "author",
+                        "comparison_operator": "is",
+                        "value": "bob"
+                    }
+                ]
+            }
+        }
     )

     if stream:
@@ -211,7 +250,11 @@ async def chat_completion_openai_like(tenant_id, chat_id):
     """
     req = await get_request_json()

-    need_reference = bool(req.get("reference", False))
+    extra_body = req.get("extra_body") or {}
+    if extra_body and not isinstance(extra_body, dict):
+        return get_error_data_result("extra_body must be an object.")
+
+    need_reference = bool(extra_body.get("reference", False))

     messages = req.get("messages", [])
     # To prevent empty [] input
@@ -229,6 +272,22 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         return get_error_data_result(f"You don't own the chat {chat_id}")
     dia = dia[0]

+    metadata_condition = extra_body.get("metadata_condition") or {}
+    if metadata_condition and not isinstance(metadata_condition, dict):
+        return get_error_data_result(message="metadata_condition must be an object.")
+
+    doc_ids_str = None
+    if metadata_condition:
+        metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
+        filtered_doc_ids = meta_filter(
+            metas,
+            convert_conditions(metadata_condition),
+            metadata_condition.get("logic", "and"),
+        )
+        if metadata_condition.get("conditions") and not filtered_doc_ids:
+            filtered_doc_ids = ["-999"]
+        doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
+
     # Filter system and non-sense assistant messages
     msg = []
     for m in messages:
@@ -249,9 +308,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
     # The choices field on the last chunk will always be an empty array [].
     async def streamed_response_generator(chat_id, dia, msg):
         token_used = 0
-        answer_cache = ""
-        reasoning_cache = ""
         last_ans = {}
+        full_content = ""
+        full_reasoning = ""
+        final_answer = None
+        final_reference = None
+        in_think = False
         response = {
             "id": f"chatcmpl-{chat_id}",
             "choices": [
@@ -276,49 +338,35 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         }

         try:
-            async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+            chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
+            if doc_ids_str:
+                chat_kwargs["doc_ids"] = doc_ids_str
+            async for ans in async_chat(dia, msg, True, **chat_kwargs):
                 last_ans = ans
-                answer = ans["answer"]
-                reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
-                if reasoning_match:
-                    reasoning_part = reasoning_match.group(1)
-                    content_part = answer[reasoning_match.end():]
-                else:
-                    reasoning_part = ""
-                    content_part = answer
-
-                reasoning_incremental = ""
-                if reasoning_part:
-                    if reasoning_part.startswith(reasoning_cache):
-                        reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
-                    else:
-                        reasoning_incremental = reasoning_part
-                    reasoning_cache = reasoning_part
-
-                content_incremental = ""
-                if content_part:
-                    if content_part.startswith(answer_cache):
-                        content_incremental = content_part.replace(answer_cache, "", 1)
-                    else:
-                        content_incremental = content_part
-                    answer_cache = content_part
-
-                token_used += len(reasoning_incremental) + len(content_incremental)
-
-                if not any([reasoning_incremental, content_incremental]):
+                if ans.get("final"):
+                    if ans.get("answer"):
+                        full_content = ans["answer"]
+                    final_answer = ans.get("answer") or full_content
+                    final_reference = ans.get("reference", {})
                     continue
-                if reasoning_incremental:
-                    response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
-                else:
-                    response["choices"][0]["delta"]["reasoning_content"] = None
-                if content_incremental:
-                    response["choices"][0]["delta"]["content"] = content_incremental
-                else:
+                if ans.get("start_to_think"):
+                    in_think = True
+                    continue
+                if ans.get("end_to_think"):
+                    in_think = False
+                    continue
+                delta = ans.get("answer") or ""
+                if not delta:
+                    continue
+                token_used += len(delta)
+                if in_think:
+                    full_reasoning += delta
+                    response["choices"][0]["delta"]["reasoning_content"] = delta
                     response["choices"][0]["delta"]["content"] = None
+                else:
+                    full_content += delta
+                    response["choices"][0]["delta"]["content"] = delta
+                    response["choices"][0]["delta"]["reasoning_content"] = None
                 yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         except Exception as e:
             response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -328,11 +376,11 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         response["choices"][0]["delta"]["content"] = None
         response["choices"][0]["delta"]["reasoning_content"] = None
         response["choices"][0]["finish_reason"] = "stop"
-        response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
-                             "total_tokens": len(prompt) + token_used}
+        response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
         if need_reference:
-            response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
-            response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
+            reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
+            response["choices"][0]["delta"]["reference"] = chunks_format(reference_payload)
+            response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
         yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         yield "data:[DONE]\n\n"

@@ -344,7 +392,10 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+        chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
+        if doc_ids_str:
+            chat_kwargs["doc_ids"] = doc_ids_str
+        async for ans in async_chat(dia, msg, False, **chat_kwargs):
            # focus answer content only
            answer = ans
            break
@@ -378,7 +429,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
             ],
         }
         if need_reference:
-            response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", []))
+            response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", {}))

         return jsonify(response)

@@ -388,7 +439,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
 @token_required
 async def agents_completion_openai_compatibility(tenant_id, agent_id):
     req = await get_request_json()
-    tiktokenenc = tiktoken.get_encoding("cl100k_base")
+    tiktoken_encode = tiktoken.get_encoding("cl100k_base")
     messages = req.get("messages", [])
     if not messages:
         return get_error_data_result("You must provide at least one message.")
@@ -396,7 +447,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
         return get_error_data_result(f"You don't own the agent {agent_id}")

     filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
-    prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
+    prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages)
     if not filtered_messages:
         return jsonify(
             get_data_openai(
@@ -404,7 +455,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
                 content="No valid messages found (user or assistant).",
                 finish_reason="stop",
                 model=req.get("model", ""),
-                completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
+                completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")),
                 prompt_tokens=prompt_tokens,
             )
         )
@@ -441,15 +492,19 @@
         ):
             return jsonify(response)

+    return None


 @manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
 @token_required
 async def agent_completions(tenant_id, agent_id):
     req = await get_request_json()
+    return_trace = bool(req.get("return_trace", False))
+
     if req.get("stream", True):

         async def generate():
+            trace_items = []
             async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
                 if isinstance(answer, str):
                     try:
@@ -457,7 +512,21 @@
                     except Exception:
                         continue

-                    if ans.get("event") not in ["message", "message_end"]:
+                    event = ans.get("event")
+                    if event == "node_finished":
+                        if return_trace:
+                            data = ans.get("data", {})
+                            trace_items.append(
+                                {
+                                    "component_id": data.get("component_id"),
+                                    "trace": [copy.deepcopy(data)],
+                                }
+                            )
+                            ans.setdefault("data", {})["trace"] = trace_items
+                            answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
+                        yield answer
+
+                    if event not in ["message", "message_end"]:
                         continue

                 yield answer
@@ -474,6 +543,7 @@
         full_content = ""
         reference = {}
         final_ans = ""
+        trace_items = []
         async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
             try:
                 ans = json.loads(answer[5:])
@@ -484,11 +554,22 @@
                 if ans.get("data", {}).get("reference", None):
                     reference.update(ans["data"]["reference"])
+
+                if return_trace and ans.get("event") == "node_finished":
+                    data = ans.get("data", {})
+                    trace_items.append(
+                        {
+                            "component_id": data.get("component_id"),
+                            "trace": [copy.deepcopy(data)],
+                        }
+                    )
+
                 final_ans = ans
             except Exception as e:
                 return get_result(data=f"**ERROR**: {str(e)}")
         final_ans["data"]["content"] = full_content
         final_ans["data"]["reference"] = reference
+        if return_trace and final_ans:
+            final_ans["data"]["trace"] = trace_items
         return get_result(data=final_ans)
@@ -832,6 +913,7 @@ async def chatbot_completions(dialog_id):
     async for answer in iframe_completion(dialog_id, **req):
         return get_result(data=answer)

+    return None

 @manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
 async def chatbots_inputs(dialog_id):
@@ -879,6 +961,7 @@ async def agent_bot_completions(agent_id):
     async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
         return get_result(data=answer)

+    return None

 @manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
 async def begin_inputs(agent_id):
@@ -894,7 +977,7 @@ async def begin_inputs(agent_id):
     if not e:
         return get_error_data_result(f"Can't find agent by ID: {agent_id}")

-    canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
+    canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id, canvas_id=cvs.id)
     return get_result(
         data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
               "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
@@ -966,11 +1049,13 @@ async def retrieval_test_embedded():
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
     langs = req.get("cross_languages", [])
+    rerank_id = req.get("rerank_id", "")
     tenant_id = objs[0].tenant_id
     if not tenant_id:
         return get_error_data_result(message="permission denined.")

     async def _retrieval():
+        nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id
         local_doc_ids = list(doc_ids) if doc_ids else []
         tenant_ids = []
         _question = question
@@ -982,6 +1067,15 @@ async def retrieval_test_embedded():
             meta_data_filter = search_config.get("meta_data_filter", {})
             if meta_data_filter.get("method") in ["auto", "semi_auto"]:
                 chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+            # Apply search_config settings if not explicitly provided in request
+            if not req.get("similarity_threshold"):
+                similarity_threshold = float(search_config.get("similarity_threshold", similarity_threshold))
+            if not req.get("vector_similarity_weight"):
+                vector_similarity_weight = float(search_config.get("vector_similarity_weight", vector_similarity_weight))
+            if not req.get("top_k"):
+                top = int(search_config.get("top_k", top))
+            if not req.get("rerank_id"):
+                rerank_id = search_config.get("rerank_id", "")
         else:
             meta_data_filter = req.get("meta_data_filter") or {}
             if meta_data_filter.get("method") in ["auto", "semi_auto"]:
@@ -1011,20 +1105,20 @@ async def retrieval_test_embedded():
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

         rerank_mdl = None
-        if req.get("rerank_id"):
-            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+        if rerank_id:
+            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=rerank_id)

         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             _question += await keyword_extraction(chat_mdl, _question)

         labels = label_question(_question, [kb])
-        ranks = settings.retriever.retrieval(
+        ranks = await settings.retriever.retrieval(
             _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
             local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
         )
         if use_kg:
-            ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
+            ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
                                                  LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)
@@ -1141,3 +1235,93 @@ async def mindmap():
     if "error" in mind_map:
         return server_error_response(Exception(mind_map["error"]))
     return get_json_result(data=mind_map)
+
+
+@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
+@token_required
+async def sequence2txt(tenant_id):
+    req = await request.form
+    stream_mode = req.get("stream", "false").lower() == "true"
+    files = await request.files
+    if "file" not in files:
+        return get_error_data_result(message="Missing 'file' in multipart form-data")
+
+    uploaded = files["file"]
+
+    ALLOWED_EXTS = {
+        ".wav", ".mp3", ".m4a", ".aac",
+        ".flac", ".ogg", ".webm",
+        ".opus", ".wma"
+    }
+
+    filename = uploaded.filename or ""
+    suffix = os.path.splitext(filename)[-1].lower()
+    if suffix not in ALLOWED_EXTS:
+        return get_error_data_result(message=
+            f"Unsupported audio format: {suffix}. "
+            f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
+        )
+    fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
+    os.close(fd)
+    await uploaded.save(temp_audio_path)
+
+    tenants = TenantService.get_info_by(tenant_id)
+    if not tenants:
+        return get_error_data_result(message="Tenant not found!")
+
+    asr_id = tenants[0]["asr_id"]
+    if not asr_id:
+        return get_error_data_result(message="No default ASR model is set")
+
+    asr_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
+    if not stream_mode:
+        text = asr_mdl.transcription(temp_audio_path)
+        try:
+            os.remove(temp_audio_path)
+        except Exception as e:
+            logging.error(f"Failed to remove temp audio file: {str(e)}")
+        return get_json_result(data={"text": text})
+
+    async def event_stream():
+        try:
+            for evt in asr_mdl.stream_transcription(temp_audio_path):
+                yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
+        except Exception as e:
+            err = {"event": "error", "text": str(e)}
+            yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
+        finally:
+            try:
+                os.remove(temp_audio_path)
+            except Exception as e:
+                logging.error(f"Failed to remove temp audio file: {str(e)}")
+
+    return Response(event_stream(), content_type="text/event-stream")
+
+
+@manager.route("/tts", methods=["POST"]) # noqa: F821
+@token_required
+async def tts(tenant_id):
+    req = await get_request_json()
+    text = req["text"]
+
+    tenants = TenantService.get_info_by(tenant_id)
+    if not tenants:
+        return get_error_data_result(message="Tenant not found!")
+
+    tts_id = tenants[0]["tts_id"]
+    if not tts_id:
+        return get_error_data_result(message="No default TTS model is set")
+
+    tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
+
+    def stream_audio():
+        try:
+            for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
+                for chunk in tts_mdl.tts(txt):
+                    yield chunk
+        except Exception as e:
+            yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
+
+    resp = Response(stream_audio(), mimetype="audio/mpeg")
+    resp.headers.add_header("Cache-Control", "no-cache")
+    resp.headers.add_header("Connection", "keep-alive")
+    resp.headers.add_header("X-Accel-Buffering", "no")
+
+    return resp
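A rough sense of how the two new audio endpoints might be exercised from a client; the host, port, token value and file name are placeholders, not part of the diff:

# hypothetical usage sketch; @token_required is assumed to accept an API token in the Authorization header
import requests
headers = {"Authorization": "Bearer <API_TOKEN>"}
with open("meeting.wav", "rb") as f:  # placeholder audio file
    r = requests.post("http://localhost:9380/api/v1/sequence2txt", headers=headers,
                      files={"file": f}, data={"stream": "false"})
print(r.json()["data"]["text"])
audio = requests.post("http://localhost:9380/api/v1/tts", headers=headers, json={"text": "hello"})
open("hello.mp3", "wb").write(audio.content)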
@@ -177,8 +177,8 @@ def healthz():
     return jsonify(result), (200 if all_ok else 500)


 @manager.route("/ping", methods=["GET"]) # noqa: F821
-def ping():
+async def ping():
     return "pong", 200


@@ -213,7 +213,7 @@ def new_token():
         if not tenants:
             return get_data_error_result(message="Tenant not found!")

-        tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
+        tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
         obj = {
             "tenant_id": tenant_id,
             "token": generate_confirmation_token(),
@@ -268,13 +268,12 @@ def token_list():
         if not tenants:
             return get_data_error_result(message="Tenant not found!")

-        tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
+        tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id
         objs = APITokenService.query(tenant_id=tenant_id)
         objs = [o.to_dict() for o in objs]
         for o in objs:
             if not o["beta"]:
-                o["beta"] = generate_confirmation_token().replace(
-                    "ragflow-", "")[:32]
+                o["beta"] = generate_confirmation_token().replace("ragflow-", "")[:32]
                 APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
         return get_json_result(data=objs)
     except Exception as e:
@@ -307,13 +306,19 @@ def rm(token):
             type: boolean
             description: Deletion status.
     """
-    APITokenService.filter_delete(
-        [APIToken.tenant_id == current_user.id, APIToken.token == token]
-    )
-    return get_json_result(data=True)
+    try:
+        tenants = UserTenantService.query(user_id=current_user.id)
+        if not tenants:
+            return get_data_error_result(message="Tenant not found!")
+
+        tenant_id = tenants[0].tenant_id
+        APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == token])
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)


-@manager.route('/config', methods=['GET']) # noqa: F821
+@manager.route("/config", methods=["GET"]) # noqa: F821
 def get_config():
     """
     Get system configuration.
@@ -330,6 +335,4 @@ def get_config():
           type: integer 0 means disabled, 1 means enabled
           description: Whether user registration is enabled
     """
-    return get_json_result(data={
-        "registerEnabled": settings.REGISTER_ENABLED
-    })
+    return get_json_result(data={"registerEnabled": settings.REGISTER_ENABLED})
@@ -660,7 +660,7 @@ def user_register(user_id, user):
     tenant_llm = get_init_tenant_llm(user_id)

     if not UserService.save(**user):
-        return
+        return None
     TenantService.insert(**tenant)
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
@@ -281,7 +281,11 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
             except Exception as e:
                 logging.error(f"Failed to reconnect: {e}")
                 time.sleep(0.1)
-                self.connect()
+                try:
+                    self.connect()
+                except Exception as e2:
+                    logging.error(f"Failed to reconnect on second attempt: {e2}")
+                    raise

     def begin(self):
         for attempt in range(self.max_retries + 1):
@@ -352,7 +356,11 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
             except Exception as e:
                 logging.error(f"Failed to reconnect to PostgreSQL: {e}")
                 time.sleep(0.1)
-                self.connect()
+                try:
+                    self.connect()
+                except Exception as e2:
+                    logging.error(f"Failed to reconnect to PostgreSQL on second attempt: {e2}")
+                    raise

     def begin(self):
         for attempt in range(self.max_retries + 1):
@@ -1189,7 +1197,7 @@ class Memory(DataBaseModel):
     permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me")
     description = TextField(null=True, help_text="description")
     memory_size = IntegerField(default=5242880, null=False, index=False)
-    forgetting_policy = CharField(max_length=32, null=False, default="fifo", index=False, help_text="lru|fifo")
+    forgetting_policy = CharField(max_length=32, null=False, default="FIFO", index=False, help_text="LRU|FIFO")
     temperature = FloatField(default=0.5, index=False)
     system_prompt = TextField(null=True, help_text="system prompt", index=False)
     user_prompt = TextField(null=True, help_text="user prompt", index=False)
@@ -1197,224 +1205,93 @@ class Memory(DataBaseModel):
     class Meta:
         db_table = "memory"

+
+class SystemSettings(DataBaseModel):
+    name = CharField(max_length=128, primary_key=True)
+    source = CharField(max_length=32, null=False, index=False)
+    data_type = CharField(max_length=32, null=False, index=False)
+    value = CharField(max_length=1024, null=False, index=False)
+
+    class Meta:
+        db_table = "system_settings"
+
+
+def alter_db_add_column(migrator, table_name, column_name, column_type):
+    try:
+        migrate(migrator.add_column(table_name, column_name, column_type))
+    except OperationalError as ex:
+        error_codes = [1060]
+        error_messages = ['Duplicate column name']
+
+        should_skip_error = (
+            (hasattr(ex, 'args') and ex.args and ex.args[0] in error_codes) or
+            (str(ex) in error_messages)
+        )
+
+        if not should_skip_error:
+            logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, operation error: {ex}")
+
+    except Exception as ex:
+        logging.critical(f"Failed to add {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name}, error: {ex}")
+        pass
+
+
+def alter_db_column_type(migrator, table_name, column_name, new_column_type):
+    try:
+        migrate(migrator.alter_column_type(table_name, column_name, new_column_type))
+    except Exception as ex:
+        logging.critical(f"Failed to alter {settings.DATABASE_TYPE.upper()}.{table_name} column {column_name} type, error: {ex}")
+        pass
+
+
+def alter_db_rename_column(migrator, table_name, old_column_name, new_column_name):
+    try:
+        migrate(migrator.rename_column(table_name, old_column_name, new_column_name))
+    except Exception:
+        # rename fail will lead to a weired error.
+        # logging.critical(f"Failed to rename {settings.DATABASE_TYPE.upper()}.{table_name} column {old_column_name} to {new_column_name}, error: {ex}")
+        pass
+
+
 def migrate_db():
     logging.disable(logging.ERROR)
     migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
-    try:
-        migrate(migrator.add_column("file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("dialog", "top_k", IntegerField(default=1024)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "retry_count", IntegerField(default=0)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("api_token", "dialog_id", CharField(max_length=32, null=True, index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={})))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")))
-    except Exception:
-        pass
-
-    try:
-        migrate(migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("document", "meta_fields", JSONField(null=True, default={})))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "task_type", CharField(max_length=32, null=False, default="")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.rename_column("task", "process_duation", "process_duration"))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.rename_column("document", "process_duation", "process_duration"))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
-    except Exception:
-        pass
-    try:
-        migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
-    except Exception:
+    alter_db_add_column(migrator, "file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True))
+    alter_db_add_column(migrator, "tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
+    alter_db_add_column(migrator, "dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
+    alter_db_column_type(migrator, "dialog", "top_k", IntegerField(default=1024))
+    alter_db_add_column(migrator, "tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True))
+    alter_db_add_column(migrator, "api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
+    alter_db_add_column(migrator, "tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
+    alter_db_add_column(migrator, "api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
+    alter_db_add_column(migrator, "task", "retry_count", IntegerField(default=0))
+    alter_db_column_type(migrator, "api_token", "dialog_id", CharField(max_length=32, null=True, index=True))
+    alter_db_add_column(migrator, "tenant_llm", "max_tokens", IntegerField(default=8192, index=True))
+    alter_db_add_column(migrator, "api_4_conversation", "dsl", JSONField(null=True, default={}))
+    alter_db_add_column(migrator, "knowledgebase", "pagerank", IntegerField(default=0, index=False))
+    alter_db_add_column(migrator, "api_token", "beta", CharField(max_length=255, null=True, index=True))
+    alter_db_add_column(migrator, "task", "digest", TextField(null=True, help_text="task digest", default=""))
+    alter_db_add_column(migrator, "task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
+    alter_db_add_column(migrator, "conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True))
+    alter_db_add_column(migrator, "document", "meta_fields", JSONField(null=True, default={}))
+    alter_db_add_column(migrator, "task", "task_type", CharField(max_length=32, null=False, default=""))
+    alter_db_add_column(migrator, "task", "priority", IntegerField(default=0))
+    alter_db_add_column(migrator, "user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True))
+    alter_db_add_column(migrator, "llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))
+    alter_db_add_column(migrator, "mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict))
+    alter_db_rename_column(migrator, "task", "process_duation", "process_duration")
+    alter_db_rename_column(migrator, "document", "process_duation", "process_duration")
+    alter_db_add_column(migrator, "document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True))
+    alter_db_add_column(migrator, "api_4_conversation", "errors", TextField(null=True, help_text="errors"))
+    alter_db_add_column(migrator, "dialog", "meta_data_filter", JSONField(null=True, default={}))
+    alter_db_column_type(migrator, "canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title"))
+    alter_db_column_type(migrator, "canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description"))
+    alter_db_add_column(migrator, "user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
+    alter_db_add_column(migrator, "canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
+    alter_db_add_column(migrator, "document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True))
+    alter_db_add_column(migrator, "knowledgebase", "raptor_task_finish_at", CharField(null=True))
+    alter_db_add_column(migrator, "knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True))
+    alter_db_add_column(migrator, "knowledgebase", "mindmap_task_finish_at", CharField(null=True))
+    alter_db_column_type(migrator, "tenant_llm", "api_key", TextField(null=True, help_text="API KEY"))
+    alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True))
+    alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False))
+    alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False))
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# RAG Evaluation tables
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logging.disable(logging.NOTSET)
|
logging.disable(logging.NOTSET)
|
||||||
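The alter_db_add_column and alter_db_column_type helpers referenced on the new side of this hunk are defined elsewhere in the revision and are not part of the excerpt above. A minimal sketch of the wrapper they presumably implement, with the signature inferred purely from the call sites (treat it as an assumption, not the repository's actual code):

# Hypothetical sketch of the idempotent migration helpers assumed by the calls above.
from playhouse.migrate import migrate


def alter_db_add_column(migrator, table, column, field):
    # Add a column, swallowing errors such as "column already exists" so
    # the migration can be re-run safely.
    try:
        migrate(migrator.add_column(table, column, field))
    except Exception:
        pass


def alter_db_column_type(migrator, table, column, field):
    # Change a column's type with the same ignore-on-failure behaviour.
    try:
        migrate(migrator.alter_column_type(table, column, field))
    except Exception:
        pass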
@@ -30,6 +30,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
 from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
 from api.db.services.user_service import TenantService, UserTenantService
+from api.db.services.system_settings_service import SystemSettingsService
+from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
 from common import settings
@@ -157,20 +159,49 @@ def add_graph_templates():
                 CanvasTemplateService.save(**cnvs)
             except Exception:
                 CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
-    except Exception:
-        logging.exception("Add agent templates error: ")
+    except Exception as e:
+        logging.exception(f"Add agent templates error: {e}")


 def init_web_data():
     start_time = time.time()

+    init_table()
     init_llm_factory()
     # if not UserService.get_all().count():
     #     init_superuser()

     add_graph_templates()
+    init_message_id_sequence()
+    init_memory_size_cache()
     logging.info("init web data success:{}".format(time.time() - start_time))


+def init_table():
+    # init system_settings
+    with open(os.path.join(get_project_base_directory(), "conf", "system_settings.json"), "r") as f:
+        records_from_file = json.load(f)["system_settings"]
+
+    record_index = {}
+    records_from_db = SystemSettingsService.get_all()
+    for index, record in enumerate(records_from_db):
+        record_index[record.name] = index
+
+    to_save = []
+    for record in records_from_file:
+        setting_name = record["name"]
+        if setting_name not in record_index:
+            to_save.append(record)
+
+    len_to_save = len(to_save)
+    if len_to_save > 0:
+        # not initialized
+        try:
+            SystemSettingsService.insert_many(to_save, len_to_save)
+        except Exception as e:
+            logging.exception("System settings init error: {}".format(e))
+            raise e
+
+
 if __name__ == '__main__':
     init_web_db()
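init_table only inserts the records from conf/system_settings.json whose name is not already present in the database. The file itself is not shown in this diff; a minimal illustrative shape of what json.load(f) would return (every key other than "name" is an assumption):

records_from_file = {
    "system_settings": [
        # hypothetical entry; only "name" is used for the duplicate check above
        {"name": "example_setting", "value": "default"},
    ]
}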
420 api/db/joint_services/memory_message_service.py Normal file
@@ -0,0 +1,420 @@
#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
from typing import List

from common import settings
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
from common.constants import MemoryType, LLMType
from common.doc_store.doc_store_base import FusionExpr
from common.misc_utils import get_uuid
from api.db.db_utils import bulk_insert_into_db
from api.db.db_models import Task
from api.db.services.task_service import TaskService
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.utils.memory_utils import get_memory_type_human
from memory.services.messages import MessageService
from memory.services.query import MsgTextQuery, get_vector
from memory.utils.prompt_util import PromptAssembler
from memory.utils.msg_util import get_json_result_from_llm_response
from rag.utils.redis_conn import REDIS_CONN


async def save_to_memory(memory_id: str, message_dict: dict):
    """
    :param memory_id:
    :param message_dict: {
        "user_id": str,
        "agent_id": str,
        "session_id": str,
        "user_input": str,
        "agent_response": str
    }
    """
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        return False, f"Memory '{memory_id}' not found."

    tenant_id = memory.tenant_id
    extracted_content = await extract_by_llm(
        tenant_id,
        memory.llm_id,
        {"temperature": memory.temperature},
        get_memory_type_human(memory.memory_type),
        message_dict.get("user_input", ""),
        message_dict.get("agent_response", "")
    ) if memory.memory_type != MemoryType.RAW.value else []  # if only RAW, no need to extract
    raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
    message_list = [{
        "message_id": raw_message_id,
        "message_type": MemoryType.RAW.name.lower(),
        "source_id": 0,
        "memory_id": memory_id,
        "user_id": "",
        "agent_id": message_dict["agent_id"],
        "session_id": message_dict["session_id"],
        "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
        "valid_at": timestamp_to_date(current_timestamp()),
        "invalid_at": None,
        "forget_at": None,
        "status": True
    }, *[{
        "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
        "message_type": content["message_type"],
        "source_id": raw_message_id,
        "memory_id": memory_id,
        "user_id": "",
        "agent_id": message_dict["agent_id"],
        "session_id": message_dict["session_id"],
        "content": content["content"],
        "valid_at": content["valid_at"],
        "invalid_at": content["invalid_at"] if content["invalid_at"] else None,
        "forget_at": None,
        "status": True
    } for content in extracted_content]]
    return await embed_and_save(memory, message_list)


async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int, task_id: str=None):
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        msg = f"Memory '{memory_id}' not found."
        if task_id:
            TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
        return False, msg

    if memory.memory_type == MemoryType.RAW.value:
        msg = f"Memory '{memory_id}' don't need to extract."
        if task_id:
            TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
        return True, msg

    tenant_id = memory.tenant_id
    extracted_content = await extract_by_llm(
        tenant_id,
        memory.llm_id,
        {"temperature": memory.temperature},
        get_memory_type_human(memory.memory_type),
        message_dict.get("user_input", ""),
        message_dict.get("agent_response", ""),
        task_id=task_id
    )
    message_list = [{
        "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
        "message_type": content["message_type"],
        "source_id": source_message_id,
        "memory_id": memory_id,
        "user_id": "",
        "agent_id": message_dict["agent_id"],
        "session_id": message_dict["session_id"],
        "content": content["content"],
        "valid_at": content["valid_at"],
        "invalid_at": content["invalid_at"] if content["invalid_at"] else None,
        "forget_at": None,
        "status": True
    } for content in extracted_content]
    if not message_list:
        msg = "No memory extracted from raw message."
        if task_id:
            TaskService.update_progress(task_id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
        return True, msg

    if task_id:
        TaskService.update_progress(task_id, {"progress": 0.5, "progress_msg": timestamp_to_date(current_timestamp()) + " " + f"Extracted {len(message_list)} messages from raw dialogue."})
    return await embed_and_save(memory, message_list, task_id)


async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
                         agent_response: str, system_prompt: str = "", user_prompt: str="", task_id: str=None) -> List[dict]:
    llm_type = TenantLLMService.llm_id2llm_type(llm_id)
    if not llm_type:
        raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
    if not system_prompt:
        system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
    conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
    conversation_time = timestamp_to_date(current_timestamp())
    user_prompts = []
    if user_prompt:
        user_prompts.append({"role": "user", "content": user_prompt})
        user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
    else:
        user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
    llm = LLMBundle(tenant_id, llm_type, llm_id)
    if task_id:
        TaskService.update_progress(task_id, {"progress": 0.15, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Prepared prompts and LLM."})
    res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
    res_json = get_json_result_from_llm_response(res)
    if task_id:
        TaskService.update_progress(task_id, {"progress": 0.35, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Get extracted result from LLM."})
    return [{
        "content": extracted_content["content"],
        "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
        "invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
        "message_type": message_type
    } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]


async def embed_and_save(memory, message_list: list[dict], task_id: str=None):
    embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
    if task_id:
        TaskService.update_progress(task_id, {"progress": 0.65, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Prepared embedding model."})
    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
    for idx, msg in enumerate(message_list):
        msg["content_embed"] = vector_list[idx]
    if task_id:
        TaskService.update_progress(task_id, {"progress": 0.85, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Embedded extracted content."})
    vector_dimension = len(vector_list[0])
    if not MessageService.has_index(memory.tenant_id, memory.id):
        created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
        if not created:
            error_msg = "Failed to create message index."
            if task_id:
                TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg})
            return False, error_msg

    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
    current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
    if new_msg_size + current_memory_size > memory.memory_size:
        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
        if memory.forgetting_policy == "FIFO":
            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id,
                                                                                                size_to_delete)
            MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
            decrease_memory_size_cache(memory.id, delete_size)
        else:
            error_msg = "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
            if task_id:
                TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg})
            return False, error_msg
    fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
    if fail_cases:
        error_msg = "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
        if task_id:
            TaskService.update_progress(task_id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + error_msg})
        return False, error_msg

    if task_id:
        TaskService.update_progress(task_id, {"progress": 0.95, "progress_msg": timestamp_to_date(current_timestamp()) + " " + "Saved messages to storage."})
    increase_memory_size_cache(memory.id, new_msg_size)
    return True, "Message saved successfully."


def query_message(filter_dict: dict, params: dict):
    """
    :param filter_dict: {
        "memory_id": List[str],
        "agent_id": optional
        "session_id": optional
    }
    :param params: {
        "query": question str,
        "similarity_threshold": float,
        "keywords_similarity_weight": float,
        "top_n": int
    }
    """
    memory_ids = filter_dict["memory_id"]
    memory_list = MemoryService.get_by_ids(memory_ids)
    if not memory_list:
        return []

    condition_dict = {k: v for k, v in filter_dict.items() if v}
    uids = [memory.tenant_id for memory in memory_list]

    question = params["query"]
    question = question.strip()
    memory = memory_list[0]
    embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
    match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
    match_text, _ = MsgTextQuery().question(question, min_match=params["similarity_threshold"])
    keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
    fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(1 - keywords_similarity_weight), str(keywords_similarity_weight)])})

    return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])


def init_message_id_sequence():
    message_id_redis_key = "id_generator:memory"
    if REDIS_CONN.exist(message_id_redis_key):
        current_max_id = REDIS_CONN.get(message_id_redis_key)
        logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
    else:
        max_id = 1
        exist_memory_list = MemoryService.get_all_memory()
        if not exist_memory_list:
            REDIS_CONN.set(message_id_redis_key, max_id)
        else:
            max_id = MessageService.get_max_message_id(
                uid_list=[m.tenant_id for m in exist_memory_list],
                memory_ids=[m.id for m in exist_memory_list]
            )
            REDIS_CONN.set(message_id_redis_key, max_id)
        logging.info(f"Init message_id sequence done, current max id is {max_id}.")


def get_memory_size_cache(memory_id: str, uid: str):
    redis_key = f"memory_{memory_id}"
    if REDIS_CONN.exist(redis_key):
        return int(REDIS_CONN.get(redis_key))
    else:
        memory_size_map = MessageService.calculate_memory_size(
            [memory_id],
            [uid]
        )
        memory_size = memory_size_map.get(memory_id, 0)
        set_memory_size_cache(memory_id, memory_size)
        return memory_size


def set_memory_size_cache(memory_id: str, size: int):
    redis_key = f"memory_{memory_id}"
    return REDIS_CONN.set(redis_key, size)


def increase_memory_size_cache(memory_id: str, size: int):
    redis_key = f"memory_{memory_id}"
    return REDIS_CONN.incrby(redis_key, size)


def decrease_memory_size_cache(memory_id: str, size: int):
    redis_key = f"memory_{memory_id}"
    return REDIS_CONN.decrby(redis_key, size)


def init_memory_size_cache():
    memory_list = MemoryService.get_all_memory()
    if not memory_list:
        logging.info("No memory found, no need to init memory size.")
    else:
        for m in memory_list:
            get_memory_size_cache(m.id, m.tenant_id)
        logging.info("Memory size cache init done.")


def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
    memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
    return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})


async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict):
    """
    :param memory_ids:
    :param message_dict: {
        "user_id": str,
        "agent_id": str,
        "session_id": str,
        "user_input": str,
        "agent_response": str
    }
    """
    def new_task(_memory_id: str, _source_id: int):
        return {
            "id": get_uuid(),
            "doc_id": _memory_id,
            "task_type": "memory",
            "progress": 0.0,
            "digest": str(_source_id)
        }

    not_found_memory = []
    failed_memory = []
    for memory_id in memory_ids:
        memory = MemoryService.get_by_memory_id(memory_id)
        if not memory:
            not_found_memory.append(memory_id)
            continue

        raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
        raw_message = {
            "message_id": raw_message_id,
            "message_type": MemoryType.RAW.name.lower(),
            "source_id": 0,
            "memory_id": memory_id,
            "user_id": "",
            "agent_id": message_dict["agent_id"],
            "session_id": message_dict["session_id"],
            "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
            "valid_at": timestamp_to_date(current_timestamp()),
            "invalid_at": None,
            "forget_at": None,
            "status": True
        }
        res, msg = await embed_and_save(memory, [raw_message])
        if not res:
            failed_memory.append({"memory_id": memory_id, "fail_msg": msg})
            continue

        task = new_task(memory_id, raw_message_id)
        bulk_insert_into_db(Task, [task], replace_on_conflict=True)
        task_message = {
            "id": task["id"],
            "task_id": task["id"],
            "task_type": task["task_type"],
            "memory_id": memory_id,
            "source_id": raw_message_id,
            "message_dict": message_dict
        }
        if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message):
            failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."})

    error_msg = ""
    if not_found_memory:
        error_msg = f"Memory {not_found_memory} not found."
    if failed_memory:
        error_msg += "".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory])

    if error_msg:
        return False, error_msg

    return True, "All add to task."


async def handle_save_to_memory_task(task_param: dict):
    """
    :param task_param: {
        "id": task_id
        "memory_id": id
        "source_id": id
        "message_dict": {
            "user_id": str,
            "agent_id": str,
            "session_id": str,
            "user_input": str,
            "agent_response": str
        }
    }
    """
    _, task = TaskService.get_by_id(task_param["id"])
    if not task:
        return False, f"Task {task_param['id']} is not found."
    if task.progress == -1:
        return False, f"Task {task_param['id']} is already failed."
    now_time = current_timestamp()
    TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)})

    memory_id = task_param["memory_id"]
    source_id = task_param["source_id"]
    message_dict = task_param["message_dict"]
    success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id, task.id)
    if success:
        TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
        return True, msg

    logging.error(msg)
    TaskService.update_progress(task.id, {"progress": -1, "progress_msg": timestamp_to_date(current_timestamp()) + " " + msg})
    return False, msg
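Taken together, queue_save_to_memory_task persists the raw dialogue turn right away and enqueues a background task, while handle_save_to_memory_task later runs the LLM extraction and embedding step for that task. A minimal caller sketch, assuming a running RAGFlow deployment (DB, Redis and the document store available); the memory ID and the payload values below are placeholders:

import asyncio

# Hypothetical payload -- field names follow the docstrings above.
message = {
    "user_id": "u-001",
    "agent_id": "agent-001",
    "session_id": "sess-001",
    "user_input": "What did we decide about the Q3 roadmap?",
    "agent_response": "We agreed to ship the memory feature first.",
}

async def demo():
    ok, detail = await queue_save_to_memory_task(["mem-001"], message)
    print(ok, detail)

asyncio.run(demo())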
@@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from api.db.services.user_service import TenantService, UserService, UserTenantService
+from api.db.services.memory_service import MemoryService
+from memory.services.messages import MessageService
 from rag.nlp import search
 from common.constants import ActiveEnum
 from common import settings
@@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
     done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
     langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
     done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
-    # step1.3 delete own tenant
+    # step1.3 delete memory and messages
+    user_memory = MemoryService.get_by_tenant_id(tenant_id)
+    if user_memory:
+        for memory in user_memory:
+            if MessageService.has_index(tenant_id, memory.id):
+                MessageService.delete_index(tenant_id, memory.id)
+                done_msg += " Deleted memory index."
+        memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
+        done_msg += f"Deleted {memory_delete_res} memory datasets."
+    # step1.4 delete own tenant
     tenant_delete_res = TenantService.delete_by_id(tenant_id)
     done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
     # step2 delete user-tenant relation
@@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
             logging.exception(e)
             return False, None

+    @classmethod
+    @DB.connection_context()
+    def get_basic_info_by_canvas_ids(cls, canvas_id):
+        fields = [
+            cls.model.id,
+            cls.model.avatar,
+            cls.model.user_id,
+            cls.model.title,
+            cls.model.permission,
+            cls.model.canvas_category
+        ]
+        return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
+
     @classmethod
     @DB.connection_context()
     def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
@@ -198,7 +211,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
         if not isinstance(cvs.dsl, str):
             cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
         session_id=get_uuid()
-        canvas = Canvas(cvs.dsl, tenant_id, agent_id)
+        canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id)
         canvas.reset()
         conv = {
             "id": session_id,
@@ -169,10 +169,12 @@ class CommonService:
         """
         if "id" not in kwargs:
             kwargs["id"] = get_uuid()
-        kwargs["create_time"] = current_timestamp()
-        kwargs["create_date"] = datetime_format(datetime.now())
-        kwargs["update_time"] = current_timestamp()
-        kwargs["update_date"] = datetime_format(datetime.now())
+        timestamp = current_timestamp()
+        cur_datetime = datetime_format(datetime.now())
+        kwargs["create_time"] = timestamp
+        kwargs["create_date"] = cur_datetime
+        kwargs["update_time"] = timestamp
+        kwargs["update_date"] = cur_datetime
         sample_obj = cls.model(**kwargs).save(force_insert=True)
         return sample_obj

@@ -188,10 +190,15 @@ class CommonService:
             data_list (list): List of dictionaries containing record data to insert.
             batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
         """
+        current_ts = current_timestamp()
+        current_datetime = datetime_format(datetime.now())
         with DB.atomic():
             for d in data_list:
-                d["create_time"] = current_timestamp()
-                d["create_date"] = datetime_format(datetime.now())
+                d["create_time"] = current_ts
+                d["create_date"] = current_datetime
+                d["update_time"] = current_ts
+                d["update_date"] = current_datetime

             for i in range(0, len(data_list), batch_size):
                 cls.model.insert_many(data_list[i : i + batch_size]).execute()

@@ -207,10 +214,14 @@ class CommonService:
             data_list (list): List of dictionaries containing record data to update.
                 Each dictionary must include an 'id' field.
         """
+        timestamp = current_timestamp()
+        cur_datetime = datetime_format(datetime.now())
+        for data in data_list:
+            data["update_time"] = timestamp
+            data["update_date"] = cur_datetime
         with DB.atomic():
             for data in data_list:
-                data["update_time"] = current_timestamp()
-                data["update_date"] = datetime_format(datetime.now())
                 cls.model.update(data).where(cls.model.id == data["id"]).execute()

     @classmethod
@@ -29,7 +29,6 @@ from common.misc_utils import get_uuid
 from common.constants import TaskStatus
 from common.time_utils import current_timestamp, timestamp_to_date


 class ConnectorService(CommonService):
     model = Connector

@@ -202,6 +201,7 @@ class SyncLogsService(CommonService):
         return None


 class FileObj(BaseModel):
+    id: str
     filename: str
     blob: bytes

@@ -209,7 +209,7 @@ class SyncLogsService(CommonService):
         return self.blob

     errs = []
-    files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
+    files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
     doc_ids = []
     err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
     errs.extend(err)
@@ -64,11 +64,13 @@ class ConversationService(CommonService):
             offset += limit
         return res


 def structure_answer(conv, ans, message_id, session_id):
     reference = ans["reference"]
     if not isinstance(reference, dict):
         reference = {}
         ans["reference"] = {}
+    is_final = ans.get("final", True)

     chunk_list = chunks_format(reference)

@@ -81,14 +83,32 @@ def structure_answer(conv, ans, message_id, session_id):

     if not conv.message:
         conv.message = []
+    content = ans["answer"]
+    if ans.get("start_to_think"):
+        content = "<think>"
+    elif ans.get("end_to_think"):
+        content = "</think>"
+
     if not conv.message or conv.message[-1].get("role", "") != "assistant":
-        conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
+        conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
     else:
-        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+        if is_final:
+            if ans.get("answer"):
+                conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+            else:
+                conv.message[-1]["created_at"] = time.time()
+                conv.message[-1]["id"] = message_id
+        else:
+            conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
+            conv.message[-1]["created_at"] = time.time()
+            conv.message[-1]["id"] = message_id
     if conv.reference:
-        conv.reference[-1] = reference
+        should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
+        if should_update_reference:
+            conv.reference[-1] = reference
     return ans


 async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
     assert name, "`name` can not be empty."
     dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)

@@ -116,6 +136,16 @@ async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
                 ensure_ascii=False) + "\n\n"
             yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
             return
+        else:
+            answer = {
+                "answer": conv["message"][0]["content"],
+                "reference": {},
+                "audio_binary": None,
+                "id": None,
+                "session_id": session_id
+            }
+            yield answer
+            return

     conv = ConversationService.query(id=session_id, dialog_id=chat_id)
     if not conv:
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import binascii
 import logging
 import re

@@ -23,7 +24,6 @@ from functools import partial
 from timeit import default_timer as timer
 from langfuse import Langfuse
 from peewee import fn
-from agentic_reasoning import DeepResearcher
 from api.db.services.file_service import FileService
 from common.constants import LLMType, ParserType, StatusEnum
 from api.db.db_models import DB, Dialog

@@ -36,7 +36,7 @@ from common.metadata_utils import apply_meta_data_filter
 from api.db.services.tenant_llm_service import TenantLLMService
 from common.time_utils import current_timestamp, datetime_format
 from graphrag.general.mind_map_extractor import MindMapExtractor
-from rag.app.resume import forbidden_select_fields4resume
+from rag.advanced_rag import DeepResearcher
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
 from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
     if attachments and msg:
         msg[-1]["content"] += attachments
     if stream:
-        last_ans = ""
-        delta_ans = ""
-        answer = ""
-        async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 16:
-                continue
-            last_ans = answer
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
-            delta_ans = ""
-        if delta_ans:
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
+        stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        async for kind, value, state in _stream_with_think_delta(stream_iter):
+            if kind == "marker":
+                flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
+                yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
+                continue
+            yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
     else:
         answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
         user_content = msg[-1].get("content", "[content not available]")
@@ -279,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):


 async def async_chat(dialog, messages, stream=True, **kwargs):
+    logging.debug("Begin async_chat")
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
         async for ans in async_chat_solo(dialog, messages, stream):

@@ -301,10 +296,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
     langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
     if langfuse_keys:
         langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
-        if langfuse.auth_check():
-            langfuse_tracer = langfuse
-            trace_id = langfuse_tracer.create_trace_id()
-            trace_context = {"trace_id": trace_id}
+        try:
+            if langfuse.auth_check():
+                langfuse_tracer = langfuse
+                trace_id = langfuse_tracer.create_trace_id()
+                trace_context = {"trace_id": trace_id}
+        except Exception:
+            # Skip langfuse tracing if connection fails
+            pass

     check_langfuse_tracer_ts = timer()
     kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
@@ -324,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs):

     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
+    logging.debug(f"field_map retrieved: {field_map}")
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
         ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
-        if ans:
+        # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
+        if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
             yield ans
             return
+        else:
+            logging.debug("SQL failed or returned no results, falling back to vector search")
+
+    param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
+    logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")

     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -367,10 +373,11 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
     kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
     knowledges = []

-    if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
+    if attachments is not None and "knowledge" in param_keys:
+        logging.debug("Proceeding with retrieval")
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
         knowledges = []
-        if prompt_config.get("reasoning", False):
+        if prompt_config.get("reasoning", False) or kwargs.get("reasoning"):
             reasoner = DeepResearcher(
                 chat_mdl,
                 prompt_config,
@@ -386,16 +393,28 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
                     doc_ids=attachments,
                 ),
             )
+            queue = asyncio.Queue()
+
+            async def callback(msg:str):
+                nonlocal queue
+                await queue.put(msg + "<br/>")
+
+            await callback("<START_DEEP_RESEARCH>")
+            task = asyncio.create_task(reasoner.research(kbinfos, questions[-1], questions[-1], callback=callback))
+            while True:
+                msg = await queue.get()
+                if msg.find("<START_DEEP_RESEARCH>") == 0:
+                    yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
+                elif msg.find("<END_DEEP_RESEARCH>") == 0:
+                    yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
+                    break
+                else:
+                    yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
+
+            await task

-            async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
-                if isinstance(think, str):
-                    thought = think
-                    knowledges = [t for t in think.split("\n") if t]
-                elif stream:
-                    yield think
         else:
             if embd_mdl:
-                kbinfos = retriever.retrieval(
+                kbinfos = await retriever.retrieval(
                     " ".join(questions),
                     embd_mdl,
                     tenant_ids,
@@ -406,12 +425,12 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
                     dialog.vector_similarity_weight,
                     doc_ids=attachments,
                     top=dialog.top_k,
-                    aggs=False,
+                    aggs=True,
                     rerank_mdl=rerank_mdl,
                     rank_feature=label_question(" ".join(questions), kbs),
                 )
                 if prompt_config.get("toc_enhance"):
-                    cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
+                    cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
                     if cks:
                         kbinfos["chunks"] = cks
                 kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
@@ -421,21 +440,19 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
                 kbinfos["chunks"].extend(tav_res["chunks"])
                 kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
             if prompt_config.get("use_kg"):
-                ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
-                                                     LLMBundle(dialog.tenant_id, LLMType.CHAT))
+                ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
+                                                           LLMBundle(dialog.tenant_id, LLMType.CHAT))
                 if ck["content_with_weight"]:
                     kbinfos["chunks"].insert(0, ck)

             knowledges = kb_prompt(kbinfos, max_tokens)

     logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     retrieval_ts = timer()
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
         yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
-               "audio_binary": tts(tts_mdl, empty_res)}
-        yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+               "audio_binary": tts(tts_mdl, empty_res), "final": True}
         return

     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@ -538,21 +555,22 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
|||||||
         )
 
     if stream:
-        last_ans = ""
+        stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
-        answer = ""
+        last_state = None
-        async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
+        async for kind, value, state in _stream_with_think_delta(stream_iter):
-            if thought:
+            last_state = state
-                ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+            if kind == "marker":
-            answer = ans
+                flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
-            delta_ans = ans[len(last_ans):]
+                yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
-            if num_tokens_from_string(delta_ans) < 16:
                 continue
-            last_ans = answer
+            yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
-            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+        full_answer = last_state.full_text if last_state else ""
-        delta_ans = answer[len(last_ans):]
+        if full_answer:
-        if delta_ans:
+            final = decorate_answer(thought + full_answer)
-            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+            final["final"] = True
-        yield decorate_answer(thought + answer)
+            final["audio_binary"] = None
+            final["answer"] = ""
+            yield final
     else:
         answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
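Note (not part of the diff): a minimal sketch of how a client might consume the reshaped stream above, where reasoning is delimited by start_to_think/end_to_think flags and the closing event carries "final": True. The event field names come from the yields in the hunk; everything else here is assumed for illustration.

    from typing import AsyncIterator

    async def render_stream(events: AsyncIterator[dict]) -> str:
        """Accumulate visible answer text as it streams, skipping <think> content."""
        in_think = False
        visible = []
        async for ev in events:
            if ev.get("start_to_think"):
                in_think = True
                continue
            if ev.get("end_to_think"):
                in_think = False
                continue
            if ev.get("final"):
                break
            if ev.get("answer") and not in_think:
                visible.append(ev["answer"])
        return "".join(visible)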
@@ -565,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
 
 
 async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
-    sys_prompt = """
+    logging.debug(f"use_sql: Question: {question}")
-You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
-Ensure that:
-1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
-2. Write only the SQL, no explanations or additional text.
-"""
-    user_prompt = """
-Table name: {};
-Table of database fields are as follows:
-{}
 
-Question are as follows:
+    # Determine which document engine we're using
+    doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es"
 
+    # Construct the full table name
+    # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
+    # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
+    base_table = index_name(tenant_id)
+    if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
+        # Infinity: append kb_id to table name
+        table_name = f"{base_table}_{kb_ids[0]}"
+        logging.debug(f"use_sql: Using Infinity table name: {table_name}")
+    else:
+        # Elasticsearch/OpenSearch: use base index name
+        table_name = base_table
+        logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
 
+    # Generate engine-specific SQL prompts
+    if doc_engine == "infinity":
+        # Build Infinity prompts with JSON extraction context
+        json_field_names = list(field_map.keys())
+        sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
+
+JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
+Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
+NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
+
+RULES:
+1. Use EXACT field names (case-sensitive) from the list below
+2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
+3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
+4. Add AS alias for extracted field names
+5. DO NOT select 'content' field
+6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
+   - Question asks to "show me" or "display" specific columns
+   - Question mentions "not null" or "excluding null"
+   - Add NULL check for count specific column
+   - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
+7. Output ONLY the SQL, no explanations"""
+        user_prompt = """Table: {}
+Fields (EXACT case): {}
 {}
-Please write the SQL, only SQL, without any other explanations or text.
+Question: {}
-""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
+Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
+            table_name,
+            ", ".join(json_field_names),
+            "\n".join([f"  - {field}" for field in json_field_names]),
+            question
+        )
+    else:
+        # Build ES/OS prompts with direct field access
+        sys_prompt = """You are a Database Administrator. Write SQL queries.
+
+RULES:
+1. Use EXACT field names from the schema below (e.g., product_tks, not product)
+2. Quote field names starting with digit: "123_field"
+3. Add IS NOT NULL in WHERE clause when:
+   - Question asks to "show me" or "display" specific columns
+4. Include doc_id/docnm in non-aggregate statement
+5. Output ONLY the SQL, no explanations"""
+        user_prompt = """Table: {}
+Available fields:
+{}
+Question: {}
+Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
+            table_name,
+            "\n".join([f"  - {k} ({v})" for k, v in field_map.items()]),
+            question
+        )
 
     tried_times = 0
 
     async def get_table():
         nonlocal sys_prompt, user_prompt, question, tried_times
         sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
-        sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
+        logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
-        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
+        # Remove think blocks if present (format: </think>...)
-        sql = re.sub(r"[\r\n]+", " ", sql.lower())
+        sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
-        sql = re.sub(r".*select ", "select ", sql.lower())
+        sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
-        sql = re.sub(r" +", " ", sql)
+        # Remove markdown code blocks (```sql ... ```)
-        sql = re.sub(r"([;;]|```).*", "", sql)
+        sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
-        sql = re.sub(r"&", "and", sql)
+        sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
-        if sql[: len("select ")] != "select ":
+        # Remove trailing semicolon that ES SQL parser doesn't like
-            return None, None
+        sql = sql.rstrip().rstrip(';').strip()
-        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
-            if sql[: len("select *")] != "select *":
-                sql = "select doc_id,docnm_kwd," + sql[6:]
-            else:
-                flds = []
-                for k in field_map.keys():
-                    if k in forbidden_select_fields4resume:
-                        continue
-                    if len(flds) > 11:
-                        break
-                    flds.append(k)
-                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
 
-        if kb_ids:
+        # Add kb_id filter for ES/OS only (Infinity already has it in table name)
-            kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
+        if doc_engine != "infinity" and kb_ids:
-            if "where" not in sql.lower():
+            # Build kb_filter: single KB or multiple KBs with OR
+            if len(kb_ids) == 1:
+                kb_filter = f"kb_id = '{kb_ids[0]}'"
+            else:
+                kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
 
+            if "where " not in sql.lower():
                 o = sql.lower().split("order by")
                 if len(o) > 1:
                     sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
                 else:
                     sql += f" WHERE {kb_filter}"
-            else:
+            elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
-                sql += f" AND {kb_filter}"
+                sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
 
         logging.debug(f"{question} get SQL(refined): {sql}")
         tried_times += 1
-        return settings.retriever.sql_retrieval(sql, format="json"), sql
+        logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
+        tbl = settings.retriever.sql_retrieval(sql, format="json")
+        if tbl is None:
+            logging.debug("use_sql: SQL retrieval returned None")
+            return None, sql
+        logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
+        return tbl, sql
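Note (not part of the diff): a minimal, self-contained sketch of the kb_id filter injection shown in the hunk above, so the string handling can be tried in isolation. The helper name and the simplified handling are assumptions; the real code also special-cases ORDER BY and the Infinity engine.

    import re

    def inject_kb_filter(sql: str, kb_ids: list[str]) -> str:
        """Append or merge a kb_id restriction into a generated SQL statement."""
        if len(kb_ids) == 1:
            kb_filter = f"kb_id = '{kb_ids[0]}'"
        else:
            kb_filter = "(" + " OR ".join(f"kb_id = '{k}'" for k in kb_ids) + ")"
        if "where " not in sql.lower():
            return sql + f" WHERE {kb_filter}"
        if "kb_id" in sql.lower():
            return sql  # the model already filtered by kb_id
        # Splice the filter right after the existing WHERE keyword.
        return re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)

    print(inject_kb_filter("select doc_id, docnm_kwd from ragflow_t1", ["kb42"]))
    # -> select doc_id, docnm_kwd from ragflow_t1 WHERE kb_id = 'kb42'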
 
     try:
         tbl, sql = await get_table()
+        logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
+        logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
     except Exception as e:
-        user_prompt = """
+        logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
+        # Build retry prompt with error information
+        if doc_engine == "infinity":
+            # Build Infinity error retry prompt
+            json_field_names = list(field_map.keys())
+            user_prompt = """
+Table name: {};
+JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
+{}
+
+Question: {}
+Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
+
+
+The SQL error you provided last time is as follows:
+{}
+
+Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
+""".format(table_name, "\n".join([f"  - {field}" for field in json_field_names]), question, e)
+        else:
+            # Build ES/OS error retry prompt
+            user_prompt = """
 Table name: {};
-Table of database fields are as follows:
+Table of database fields are as follows (use the field names directly in SQL):
 {}
 
 Question are as follows:
 {}
-Please write the SQL, only SQL, without any other explanations or text.
+Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
 
 
 The SQL error you provided last time is as follows:
 {}
 
-Please correct the error and write SQL again, only SQL, without any other explanations or text.
+Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
-""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
+""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
         try:
             tbl, sql = await get_table()
+            logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
+            logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
         except Exception:
+            logging.error("use_sql: Retry SQL execution also FAILED, returning None")
             return
 
     if len(tbl["rows"]) == 0:
+        logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
         return None
 
-    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
+    logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
-    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
+    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
+    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
+
+    logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
+    logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
 
     column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
 
+    logging.debug(f"use_sql: column_idx={column_idx}")
+    logging.debug(f"use_sql: field_map={field_map}")
+
+    # Helper function to map column names to display names
+    def map_column_name(col_name):
+        if col_name.lower() == "count(star)":
+            return "COUNT(*)"
+
+        # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
+        # Pattern: anything AS alias_name
+        as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
+        if as_match:
+            alias = as_match.group(1).strip('"\'')
+
+            # Use the alias for display name lookup
+            if alias in field_map:
+                display = field_map[alias]
+                return re.sub(r"(/.*|([^()]+))", "", display)
+            # If alias not in field_map, try to match case-insensitively
+            for field_key, display_value in field_map.items():
+                if field_key.lower() == alias.lower():
+                    return re.sub(r"(/.*|([^()]+))", "", display_value)
+            # Return alias as-is if no mapping found
+            return alias
+
+        # Try direct mapping first (for simple column names)
+        if col_name in field_map:
+            display = field_map[col_name]
+            # Clean up any suffix patterns
+            return re.sub(r"(/.*|([^()]+))", "", display)
+
+        # Try case-insensitive match for simple column names
+        col_lower = col_name.lower()
+        for field_key, display_value in field_map.items():
+            if field_key.lower() == col_lower:
+                return re.sub(r"(/.*|([^()]+))", "", display_value)
+
+        # For aggregate expressions or complex expressions without AS alias,
+        # try to replace field names with display names
+        result = col_name
+        for field_name, display_name in field_map.items():
+            # Replace field_name with display_name in the expression
+            result = result.replace(field_name, display_name)
+
+        # Clean up any suffix patterns
+        result = re.sub(r"(/.*|([^()]+))", "", result)
+        return result
+
     # compose Markdown table
     columns = (
         "|" + "|".join(
-            [re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
+            [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
-            "|Source|" if docid_idx and docid_idx else "|")
+            "|Source|" if docid_idx and doc_name_idx else "|")
     )
 
     line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
 
-    rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
+    # Build rows ensuring column names match values - create a dict for each row
-    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
+    # keyed by column name to handle any SQL column order
+    rows = []
+    for row_idx, r in enumerate(tbl["rows"]):
+        row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
+        if row_idx == 0:
+            logging.debug(f"use_sql: First row data: {row_dict}")
+        row_values = []
+        for col_idx in column_idx:
+            col_name = tbl["columns"][col_idx]["name"]
+            value = row_dict.get(col_name, " ")
+            row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
+        # Add Source column with citation marker if Source column exists
+        if docid_idx and doc_name_idx:
+            row_values.append(f" ##{row_idx}$$")
+        row_str = "|" + "|".join(row_values) + "|"
+        if re.sub(r"[ |]+", "", row_str):
+            rows.append(row_str)
     if quota:
-        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+        rows = "\n".join(rows)
     else:
-        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
+        rows = "\n".join(rows)
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
 
     if not docid_idx or not doc_name_idx:
-        logging.warning("SQL missing field: " + sql)
+        logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
+        # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
+        # to provide source chunks, but keep the original table format answer
+        if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
+            # Keep original table format as answer
+            answer = "\n".join([columns, line, rows])
+
+            # Now fetch doc_id, docnm_kwd to provide source chunks
+            # Extract WHERE clause from the original SQL
+            where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
+            if where_match:
+                where_clause = where_match.group(1).strip()
+                # Build a query to get doc_id and docnm_kwd with the same WHERE clause
+                chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
+                # Add LIMIT to avoid fetching too many chunks
+                if "limit" not in chunks_sql.lower():
+                    chunks_sql += " limit 20"
+                logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
+                try:
+                    chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
+                    if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
+                        # Build chunks reference - use case-insensitive matching
+                        chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
+                        chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
+                        if chunks_did_idx is not None and chunks_dn_idx is not None:
+                            chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
+                            # Build doc_aggs
+                            doc_aggs = {}
+                            for r in chunks_tbl["rows"]:
+                                doc_id = r[chunks_did_idx]
+                                doc_name = r[chunks_dn_idx]
+                                if doc_id not in doc_aggs:
+                                    doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
+                                doc_aggs[doc_id]["count"] += 1
+                            doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
+                            logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
+                            return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
+                except Exception as e:
+                    logging.warning(f"use_sql: Failed to fetch chunks: {e}")
+            # Fallback: return answer without chunks
+            return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
+        # Fallback to table format for other cases
         return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
 
     docid_idx = list(docid_idx)[0]
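Note (not part of the diff): a minimal sketch of the engine-dependent table naming described in the comments above, so the convention is easy to check by hand. The function name is an assumption; it only mirrors the ragflow_{tenant_id} vs ragflow_{tenant_id}_{kb_id} pattern from the hunk.

    def sql_table_name(tenant_id: str, kb_ids: list[str], doc_engine: str) -> str:
        """Return the table/index name a generated SQL statement should target."""
        base_table = f"ragflow_{tenant_id}"
        # Infinity keeps one table per knowledge base, so a single kb_id is folded
        # into the table name; ES/OpenSearch filter by kb_id in the WHERE clause.
        if doc_engine == "infinity" and len(kb_ids) == 1:
            return f"{base_table}_{kb_ids[0]}"
        return base_table

    assert sql_table_name("t1", ["kb42"], "infinity") == "ragflow_t1_kb42"
    assert sql_table_name("t1", ["kb42", "kb43"], "es") == "ragflow_t1"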
@@ -680,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text.
         if r[docid_idx] not in doc_aggs:
             doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
         doc_aggs[r[docid_idx]]["count"] += 1
-    return {
+    result = {
         "answer": "\n".join([columns, line, rows]),
         "reference": {
             "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@@ -688,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text.
         },
         "prompt": sys_prompt,
     }
+    logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
+    return result
 
 def clean_tts_text(text: str) -> str:
     if not text:
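Note (not part of the diff): a minimal sketch of the Markdown-table composition used by use_sql, with a Source column holding the ##N$$ citation markers that downstream answer rendering resolves against the returned chunks. The helper name and sample data are made up for illustration.

    def compose_markdown_table(headers, data_rows):
        """Render rows as a Markdown table with one citation marker per row."""
        head = "|" + "|".join(headers) + "|Source|"
        sep = "|" + "|".join(["------"] * len(headers)) + "|------|"
        body = [
            "|" + "|".join(str(v) for v in row) + f"| ##{i}$$|"
            for i, row in enumerate(data_rows)
        ]
        return "\n".join([head, sep] + body)

    print(compose_markdown_table(["Name", "Year"], [("Alice", 2021), ("Bob", 2022)]))
    # |Name|Year|Source|
    # |------|------|------|
    # |Alice|2021| ##0$$|
    # |Bob|2022| ##1$$|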
@@ -733,6 +948,84 @@ def tts(tts_mdl, text):
         return None
     return binascii.hexlify(bin).decode("utf-8")
 
 
+class _ThinkStreamState:
+    def __init__(self) -> None:
+        self.full_text = ""
+        self.last_idx = 0
+        self.endswith_think = False
+        self.last_full = ""
+        self.last_model_full = ""
+        self.in_think = False
+        self.buffer = ""
+
+
+def _next_think_delta(state: _ThinkStreamState) -> str:
+    full_text = state.full_text
+    if full_text == state.last_full:
+        return ""
+    state.last_full = full_text
+    delta_ans = full_text[state.last_idx:]
+
+    if delta_ans.find("<think>") == 0:
+        state.last_idx += len("<think>")
+        return "<think>"
+    if delta_ans.find("<think>") > 0:
+        delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
+        state.last_idx += delta_ans.find("<think>")
+        return delta_text
+    if delta_ans.endswith("</think>"):
+        state.endswith_think = True
+    elif state.endswith_think:
+        state.endswith_think = False
+        return "</think>"
+
+    state.last_idx = len(full_text)
+    if full_text.endswith("</think>"):
+        state.last_idx -= len("</think>")
+    return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+
+async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
+    state = _ThinkStreamState()
+    async for chunk in stream_iter:
+        if not chunk:
+            continue
+        if chunk.startswith(state.last_model_full):
+            new_part = chunk[len(state.last_model_full):]
+            state.last_model_full = chunk
+        else:
+            new_part = chunk
+            state.last_model_full += chunk
+        if not new_part:
+            continue
+        state.full_text += new_part
+        delta = _next_think_delta(state)
+        if not delta:
+            continue
+        if delta in ("<think>", "</think>"):
+            if delta == "<think>" and state.in_think:
+                continue
+            if delta == "</think>" and not state.in_think:
+                continue
+            if state.buffer:
+                yield ("text", state.buffer, state)
+                state.buffer = ""
+            state.in_think = delta == "<think>"
+            yield ("marker", delta, state)
+            continue
+        state.buffer += delta
+        if num_tokens_from_string(state.buffer) < min_tokens:
+            continue
+        yield ("text", state.buffer, state)
+        state.buffer = ""
+
+    if state.buffer:
+        yield ("text", state.buffer, state)
+        state.buffer = ""
+    if state.endswith_think:
+        yield ("marker", "</think>", state)
 
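Note (not part of the diff): a small, self-contained sketch of the cumulative-snapshot handling that _stream_with_think_delta performs — some providers stream the full text so far, others stream only the new suffix, and the helper normalizes both into plain deltas before splitting out <think> markers. Names here are illustrative only.

    import asyncio

    async def fake_stream():
        # Simulates a provider that re-sends the whole text so far on each chunk.
        for snapshot in ["<think>plan", "<think>plan</think>", "<think>plan</think>Answer: 42"]:
            yield snapshot

    async def to_deltas(stream):
        """Turn cumulative snapshots (or plain deltas) into incremental text pieces."""
        seen = ""
        async for chunk in stream:
            new_part = chunk[len(seen):] if chunk.startswith(seen) else chunk
            seen = chunk if chunk.startswith(seen) else seen + chunk
            if new_part:
                yield new_part

    async def main():
        print([d async for d in to_deltas(fake_stream())])
        # ['<think>plan', '</think>', 'Answer: 42']

    asyncio.run(main())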
 async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     doc_ids = search_config.get("doc_ids", [])
     rerank_mdl = None
@@ -758,7 +1051,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         metas = DocumentService.get_meta_by_kbs(kb_ids)
         doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
 
-    kbinfos = retriever.retrieval(
+    kbinfos = await retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,
@@ -769,7 +1062,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
         top=search_config.get("top_k", 1024),
         doc_ids=doc_ids,
-        aggs=False,
+        aggs=True,
         rerank_mdl=rerank_mdl,
         rank_feature=label_question(question, kbs)
     )
@@ -798,11 +1091,20 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
refs["chunks"] = chunks_format(refs)
|
refs["chunks"] = chunks_format(refs)
|
||||||
return {"answer": answer, "reference": refs}
|
return {"answer": answer, "reference": refs}
|
||||||
|
|
||||||
answer = ""
|
stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
|
||||||
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
last_state = None
|
||||||
answer = ans
|
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||||
yield {"answer": answer, "reference": {}}
|
last_state = state
|
||||||
yield decorate_answer(answer)
|
if kind == "marker":
|
||||||
|
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||||
|
yield {"answer": "", "reference": {}, "final": False, **flags}
|
||||||
|
continue
|
||||||
|
yield {"answer": value, "reference": {}, "final": False}
|
||||||
|
full_answer = last_state.full_text if last_state else ""
|
||||||
|
final = decorate_answer(full_answer)
|
||||||
|
final["final"] = True
|
||||||
|
final["answer"] = ""
|
||||||
|
yield final
|
||||||
|
|
||||||
|
|
||||||
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||||
@@ -825,7 +1127,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
         metas = DocumentService.get_meta_by_kbs(kb_ids)
         doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
 
-    ranks = settings.retriever.retrieval(
+    ranks = await settings.retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,
|
|||||||
@ -33,12 +33,13 @@ from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTena
|
|||||||
from api.db.db_utils import bulk_insert_into_db
|
from api.db.db_utils import bulk_insert_into_db
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from common.metadata_utils import dedupe_list
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
from common.time_utils import current_timestamp, get_format_time
|
from common.time_utils import current_timestamp, get_format_time
|
||||||
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
|
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
|
||||||
from rag.nlp import rag_tokenizer, search
|
from rag.nlp import rag_tokenizer, search
|
||||||
from rag.utils.redis_conn import REDIS_CONN
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
from rag.utils.doc_store_conn import OrderByExpr
|
from common.doc_store.doc_store_base import OrderByExpr
|
||||||
from common import settings
|
from common import settings
|
||||||
|
|
||||||
|
|
||||||
@ -124,26 +125,26 @@ class DocumentService(CommonService):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids=None, return_empty_metadata=False):
|
||||||
orderby, desc, keywords, run_status, types, suffix, doc_ids=None):
|
|
||||||
fields = cls.get_cls_model_fields()
|
fields = cls.get_cls_model_fields()
|
||||||
if keywords:
|
if keywords:
|
||||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
docs = (
|
||||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])
|
||||||
.join(File, on=(File.id == File2Document.file_id))\
|
.join(File2Document, on=(File2Document.document_id == cls.model.id))
|
||||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
.join(File, on=(File.id == File2Document.file_id))
|
||||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)
|
||||||
.where(
|
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
|
||||||
(cls.model.kb_id == kb_id),
|
.where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())))
|
||||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
docs = (
|
||||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])
|
||||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
.join(File2Document, on=(File2Document.document_id == cls.model.id))
|
||||||
.join(File, on=(File.id == File2Document.file_id))\
|
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)
|
||||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
.join(File, on=(File.id == File2Document.file_id))
|
||||||
|
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
|
||||||
.where(cls.model.kb_id == kb_id)
|
.where(cls.model.kb_id == kb_id)
|
||||||
|
)
|
||||||
|
|
||||||
if doc_ids:
|
if doc_ids:
|
||||||
docs = docs.where(cls.model.id.in_(doc_ids))
|
docs = docs.where(cls.model.id.in_(doc_ids))
|
||||||
@ -153,6 +154,8 @@ class DocumentService(CommonService):
|
|||||||
docs = docs.where(cls.model.type.in_(types))
|
docs = docs.where(cls.model.type.in_(types))
|
||||||
if suffix:
|
if suffix:
|
||||||
docs = docs.where(cls.model.suffix.in_(suffix))
|
docs = docs.where(cls.model.suffix.in_(suffix))
|
||||||
|
if return_empty_metadata:
|
||||||
|
docs = docs.where(fn.COALESCE(fn.JSON_LENGTH(cls.model.meta_fields), 0) == 0)
|
||||||
|
|
||||||
count = docs.count()
|
count = docs.count()
|
||||||
if desc:
|
if desc:
|
||||||
@ -160,7 +163,6 @@ class DocumentService(CommonService):
|
|||||||
else:
|
else:
|
||||||
docs = docs.order_by(cls.model.getter_by(orderby).asc())
|
docs = docs.order_by(cls.model.getter_by(orderby).asc())
|
||||||
|
|
||||||
|
|
||||||
if page_number and items_per_page:
|
if page_number and items_per_page:
|
||||||
docs = docs.paginate(page_number, items_per_page)
|
docs = docs.paginate(page_number, items_per_page)
|
||||||
|
|
||||||
@ -180,6 +182,16 @@ class DocumentService(CommonService):
|
|||||||
"1": 2,
|
"1": 2,
|
||||||
"2": 2
|
"2": 2
|
||||||
}
|
}
|
||||||
|
"metadata": {
|
||||||
|
"key1": {
|
||||||
|
"key1_value1": 1,
|
||||||
|
"key1_value2": 2,
|
||||||
|
},
|
||||||
|
"key2": {
|
||||||
|
"key2_value1": 2,
|
||||||
|
"key2_value2": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
}, total
|
}, total
|
||||||
where "1" => RUNNING, "2" => CANCEL
|
where "1" => RUNNING, "2" => CANCEL
|
||||||
"""
|
"""
|
||||||
@ -200,19 +212,42 @@ class DocumentService(CommonService):
|
|||||||
if suffix:
|
if suffix:
|
||||||
query = query.where(cls.model.suffix.in_(suffix))
|
query = query.where(cls.model.suffix.in_(suffix))
|
||||||
|
|
||||||
rows = query.select(cls.model.run, cls.model.suffix)
|
rows = query.select(cls.model.run, cls.model.suffix, cls.model.meta_fields)
|
||||||
total = rows.count()
|
total = rows.count()
|
||||||
|
|
||||||
suffix_counter = {}
|
suffix_counter = {}
|
||||||
run_status_counter = {}
|
run_status_counter = {}
|
||||||
|
metadata_counter = {}
|
||||||
|
empty_metadata_count = 0
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
|
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
|
||||||
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
|
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
|
||||||
|
meta_fields = row.meta_fields or {}
|
||||||
|
if not meta_fields:
|
||||||
|
empty_metadata_count += 1
|
||||||
|
continue
|
||||||
|
has_valid_meta = False
|
||||||
|
for key, value in meta_fields.items():
|
||||||
|
values = value if isinstance(value, list) else [value]
|
||||||
|
for vv in values:
|
||||||
|
if vv is None:
|
||||||
|
continue
|
||||||
|
if isinstance(vv, str) and not vv.strip():
|
||||||
|
continue
|
||||||
|
sv = str(vv)
|
||||||
|
if key not in metadata_counter:
|
||||||
|
metadata_counter[key] = {}
|
||||||
|
metadata_counter[key][sv] = metadata_counter[key].get(sv, 0) + 1
|
||||||
|
has_valid_meta = True
|
||||||
|
if not has_valid_meta:
|
||||||
|
empty_metadata_count += 1
|
||||||
|
|
||||||
|
metadata_counter["empty_metadata"] = {"true": empty_metadata_count}
|
||||||
return {
|
return {
|
||||||
"suffix": suffix_counter,
|
"suffix": suffix_counter,
|
||||||
"run_status": run_status_counter
|
"run_status": run_status_counter,
|
||||||
|
"metadata": metadata_counter,
|
||||||
}, total
|
}, total
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -305,28 +340,35 @@ class DocumentService(CommonService):
     def remove_document(cls, doc, tenant_id):
         from api.db.services.task_service import TaskService
         cls.clear_chunk_num(doc.id)
 
+        # Delete tasks first
         try:
             TaskService.filter_delete([Task.doc_id == doc.id])
-            page = 0
+        except Exception as e:
-            page_size = 1000
+            logging.warning(f"Failed to delete tasks for document {doc.id}: {e}")
-            all_chunk_ids = []
-            while True:
+        # Delete chunk images (non-critical, log and continue)
-                chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
+        try:
-                                                      page * page_size, page_size, search.index_name(tenant_id),
+            cls.delete_chunk_images(doc, tenant_id)
-                                                      [doc.kb_id])
+        except Exception as e:
-                chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
+            logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}")
-                if not chunk_ids:
-                    break
+        # Delete thumbnail (non-critical, log and continue)
-                all_chunk_ids.extend(chunk_ids)
+        try:
-                page += 1
-            for cid in all_chunk_ids:
-                if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
-                    settings.STORAGE_IMPL.rm(doc.kb_id, cid)
             if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
                 if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
                     settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
-            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+        except Exception as e:
+            logging.warning(f"Failed to delete thumbnail for document {doc.id}: {e}")
+
+        # Delete chunks from doc store - this is critical, log errors
+        try:
+            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+        except Exception as e:
+            logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}")
+
+        # Cleanup knowledge graph references (non-critical, log and continue)
+        try:
             graph_source = settings.docStoreConn.get_fields(
                 settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
             )
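Note (not part of the diff): a small sketch of the cleanup pattern the hunk above moves to — each non-critical step runs in its own try/except so one storage failure no longer aborts the whole document removal. Step names and the logging setup are illustrative.

    import logging

    def remove_with_isolated_steps(steps):
        """Run cleanup callables independently; log failures instead of raising."""
        for name, step in steps:
            try:
                step()
            except Exception as e:
                logging.warning("cleanup step %s failed: %s", name, e)

    def failing_step():
        raise RuntimeError("storage down")

    remove_with_isolated_steps([
        ("delete_tasks", lambda: None),
        ("delete_chunk_images", failing_step),
        ("delete_thumbnail", lambda: None),
    ])
    # Only the failing step is logged; the remaining steps still execute.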
@@ -339,10 +381,28 @@ class DocumentService(CommonService):
                                              search.index_name(tenant_id), doc.kb_id)
             settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
                                          search.index_name(tenant_id), doc.kb_id)
-        except Exception:
+        except Exception as e:
-            pass
+            logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
 
         return cls.delete_by_id(doc.id)
 
+    @classmethod
+    @DB.connection_context()
+    def delete_chunk_images(cls, doc, tenant_id):
+        page = 0
+        page_size = 1000
+        while True:
+            chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
+                                                  page * page_size, page_size, search.index_name(tenant_id),
+                                                  [doc.kb_id])
+            chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
+            if not chunk_ids:
+                break
+            for cid in chunk_ids:
+                if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
+                    settings.STORAGE_IMPL.rm(doc.kb_id, cid)
+            page += 1
 
     @classmethod
     @DB.connection_context()
     def get_newly_uploaded(cls):
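Note (not part of the diff): a minimal sketch of the page-by-page scan that delete_chunk_images performs — fetch chunk ids in fixed-size pages and remove the associated stored objects until a page comes back empty. The storage argument is a stand-in for the project's STORAGE_IMPL and is an assumption.

    def purge_objects(fetch_page, storage, bucket, page_size=1000):
        """fetch_page(offset, limit) -> list of object ids; returns the count removed."""
        removed = 0
        page = 0
        while True:
            ids = fetch_page(page * page_size, page_size)
            if not ids:
                break
            for oid in ids:
                # Only remove objects that actually exist in the bucket.
                if storage.exists(bucket, oid):
                    storage.remove(bucket, oid)
                    removed += 1
            page += 1
        return removed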
@ -385,6 +445,7 @@ class DocumentService(CommonService):
|
|||||||
.where(
|
.where(
|
||||||
cls.model.status == StatusEnum.VALID.value,
|
cls.model.status == StatusEnum.VALID.value,
|
||||||
~(cls.model.type == FileType.VIRTUAL.value),
|
~(cls.model.type == FileType.VIRTUAL.value),
|
||||||
|
((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
|
||||||
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
|
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
|
||||||
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
|
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
|
||||||
return list(docs.dicts())
|
return list(docs.dicts())
|
||||||
@ -665,10 +726,14 @@ class DocumentService(CommonService):
|
|||||||
for k,v in r.meta_fields.items():
|
for k,v in r.meta_fields.items():
|
||||||
if k not in meta:
|
if k not in meta:
|
||||||
meta[k] = {}
|
meta[k] = {}
|
||||||
v = str(v)
|
if not isinstance(v, list):
|
||||||
if v not in meta[k]:
|
v = [v]
|
||||||
meta[k][v] = []
|
for vv in v:
|
||||||
meta[k][v].append(doc_id)
|
if vv not in meta[k]:
|
||||||
|
if isinstance(vv, list) or isinstance(vv, dict):
|
||||||
|
continue
|
||||||
|
meta[k][vv] = []
|
||||||
|
meta[k][vv].append(doc_id)
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -711,10 +776,25 @@ class DocumentService(CommonService):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_metadata_summary(cls, kb_id):
|
def get_metadata_summary(cls, kb_id, document_ids=None):
|
||||||
|
def _meta_value_type(value):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, list):
|
||||||
|
return "list"
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return "string"
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return "number"
|
||||||
|
return "string"
|
||||||
|
|
||||||
fields = [cls.model.id, cls.model.meta_fields]
|
fields = [cls.model.id, cls.model.meta_fields]
|
||||||
summary = {}
|
summary = {}
|
||||||
for r in cls.model.select(*fields).where(cls.model.kb_id == kb_id):
|
type_counter = {}
|
||||||
|
query = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||||
|
if document_ids:
|
||||||
|
query = query.where(cls.model.id.in_(document_ids))
|
||||||
|
for r in query:
|
||||||
meta_fields = r.meta_fields or {}
|
meta_fields = r.meta_fields or {}
|
||||||
if isinstance(meta_fields, str):
|
if isinstance(meta_fields, str):
|
||||||
try:
|
try:
|
||||||
@ -724,6 +804,11 @@ class DocumentService(CommonService):
|
|||||||
if not isinstance(meta_fields, dict):
|
if not isinstance(meta_fields, dict):
|
||||||
continue
|
continue
|
||||||
for k, v in meta_fields.items():
|
for k, v in meta_fields.items():
|
||||||
|
value_type = _meta_value_type(v)
|
||||||
|
if value_type:
|
||||||
|
if k not in type_counter:
|
||||||
|
type_counter[k] = {}
|
||||||
|
type_counter[k][value_type] = type_counter[k].get(value_type, 0) + 1
|
||||||
values = v if isinstance(v, list) else [v]
|
values = v if isinstance(v, list) else [v]
|
||||||
for vv in values:
|
for vv in values:
|
||||||
if not vv:
|
if not vv:
|
||||||
@ -732,11 +817,19 @@ class DocumentService(CommonService):
|
|||||||
if k not in summary:
|
if k not in summary:
|
||||||
summary[k] = {}
|
summary[k] = {}
|
||||||
summary[k][sv] = summary[k].get(sv, 0) + 1
|
summary[k][sv] = summary[k].get(sv, 0) + 1
|
||||||
return {k: sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True) for k, v in summary.items()}
|
result = {}
|
||||||
|
for k, v in summary.items():
|
||||||
|
values = sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True)
|
||||||
|
type_counts = type_counter.get(k, {})
|
||||||
|
value_type = "string"
|
||||||
|
if type_counts:
|
||||||
|
value_type = max(type_counts.items(), key=lambda item: item[1])[0]
|
||||||
|
result[k] = {"type": value_type, "values": values}
|
||||||
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None):
|
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None, adds=None):
|
||||||
updates = updates or []
|
updates = updates or []
|
||||||
deletes = deletes or []
|
deletes = deletes or []
|
||||||
if not doc_ids:
|
if not doc_ids:
|
||||||
@ -759,14 +852,26 @@ class DocumentService(CommonService):
|
|||||||
changed = False
|
changed = False
|
||||||
for upd in updates:
|
for upd in updates:
|
||||||
key = upd.get("key")
|
key = upd.get("key")
|
||||||
if not key or key not in meta:
|
if not key:
|
||||||
continue
|
continue
|
||||||
|
if key not in meta:
|
||||||
|
meta[key] = upd.get("value")
|
||||||
|
|
||||||
new_value = upd.get("value")
|
new_value = upd.get("value")
|
||||||
match_provided = "match" in upd
|
match_provided = "match" in upd
|
||||||
|
if key not in meta:
|
||||||
|
if match_provided:
|
||||||
|
continue
|
||||||
|
meta[key] = dedupe_list(new_value) if isinstance(new_value, list) else new_value
|
||||||
|
changed = True
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(meta[key], list):
|
if isinstance(meta[key], list):
|
||||||
if not match_provided:
|
if not match_provided:
|
||||||
meta[key] = new_value
|
if isinstance(new_value, list):
|
||||||
|
meta[key] = dedupe_list(new_value)
|
||||||
|
else:
|
||||||
|
meta[key] = new_value
|
||||||
changed = True
|
changed = True
|
||||||
else:
|
else:
|
||||||
match_value = upd.get("match")
|
match_value = upd.get("match")
|
||||||
@ -779,7 +884,7 @@ class DocumentService(CommonService):
|
|||||||
else:
|
else:
|
||||||
new_list.append(item)
|
new_list.append(item)
|
||||||
if replaced:
|
if replaced:
|
||||||
meta[key] = new_list
|
meta[key] = dedupe_list(new_list)
|
||||||
changed = True
|
changed = True
|
||||||
else:
|
else:
|
||||||
if not match_provided:
|
if not match_provided:
|
||||||
@ -820,7 +925,7 @@ class DocumentService(CommonService):
|
|||||||
updated_docs = 0
|
updated_docs = 0
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
|
rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
|
||||||
(cls.model.id.in_(doc_ids)) & (cls.model.kb_id == kb_id)
|
cls.model.id.in_(doc_ids)
|
||||||
)
|
)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
meta = _normalize_meta(r.meta_fields or {})
|
meta = _normalize_meta(r.meta_fields or {})
|
||||||
@ -869,6 +974,8 @@ class DocumentService(CommonService):
|
|||||||
bad = 0
|
bad = 0
|
||||||
e, doc = DocumentService.get_by_id(d["id"])
|
e, doc = DocumentService.get_by_id(d["id"])
|
||||||
status = doc.run # TaskStatus.RUNNING.value
|
status = doc.run # TaskStatus.RUNNING.value
|
||||||
|
if status == TaskStatus.CANCEL.value:
|
||||||
|
continue
|
||||||
doc_progress = doc.progress if doc and doc.progress else 0.0
|
doc_progress = doc.progress if doc and doc.progress else 0.0
|
||||||
special_task_running = False
|
special_task_running = False
|
||||||
priority = 0
|
priority = 0
|
||||||
@ -912,7 +1019,16 @@ class DocumentService(CommonService):
|
|||||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||||
else:
|
else:
|
||||||
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||||
cls.update_by_id(d["id"], info)
|
info["update_time"] = current_timestamp()
|
||||||
|
info["update_date"] = get_format_time()
|
||||||
|
(
|
||||||
|
cls.model.update(info)
|
||||||
|
.where(
|
||||||
|
(cls.model.id == d["id"])
|
||||||
|
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("'0'") < 0:
|
if str(e).find("'0'") < 0:
|
||||||
logging.exception("fetch task exception")
|
logging.exception("fetch task exception")
|
||||||
@ -945,7 +1061,7 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
|
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
|
||||||
# cancelled: run == "2" but progress can vary
|
# cancelled: run == "2"
|
||||||
cancelled = (
|
cancelled = (
|
||||||
cls.model.select(fn.COUNT(1))
|
cls.model.select(fn.COUNT(1))
|
||||||
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
|
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
|
||||||
@ -1199,8 +1315,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
d["q_%d_vec" % len(v)] = v
|
d["q_%d_vec" % len(v)] = v
|
||||||
for b in range(0, len(cks), es_bulk_size):
|
for b in range(0, len(cks), es_bulk_size):
|
||||||
if try_create_idx:
|
if try_create_idx:
|
||||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
if not settings.docStoreConn.index_exist(idxnm, kb_id):
|
||||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
|
||||||
try_create_idx = False
|
try_create_idx = False
|
||||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||||
|
|
||||||
|
|||||||
@ -65,6 +65,7 @@ class EvaluationService(CommonService):
|
|||||||
(success, dataset_id or error_message)
|
(success, dataset_id or error_message)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
timestamp= current_timestamp()
|
||||||
dataset_id = get_uuid()
|
dataset_id = get_uuid()
|
||||||
dataset = {
|
dataset = {
|
||||||
"id": dataset_id,
|
"id": dataset_id,
|
||||||
@ -73,8 +74,8 @@ class EvaluationService(CommonService):
|
|||||||
"description": description,
|
"description": description,
|
||||||
"kb_ids": kb_ids,
|
"kb_ids": kb_ids,
|
||||||
"created_by": user_id,
|
"created_by": user_id,
|
||||||
"create_time": current_timestamp(),
|
"create_time": timestamp,
|
||||||
"update_time": current_timestamp(),
|
"update_time": timestamp,
|
||||||
"status": StatusEnum.VALID.value
|
"status": StatusEnum.VALID.value
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,21 +225,36 @@ class EvaluationService(CommonService):
|
|||||||
"""
|
"""
|
||||||
success_count = 0
|
success_count = 0
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
|
case_instances = []
|
||||||
|
|
||||||
for case_data in cases:
|
if not cases:
|
||||||
success, _ = cls.add_test_case(
|
return success_count, failure_count
|
||||||
dataset_id=dataset_id,
|
|
||||||
question=case_data.get("question", ""),
|
|
||||||
reference_answer=case_data.get("reference_answer"),
|
|
||||||
relevant_doc_ids=case_data.get("relevant_doc_ids"),
|
|
||||||
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
|
|
||||||
metadata=case_data.get("metadata")
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
cur_timestamp = current_timestamp()
|
||||||
success_count += 1
|
|
||||||
else:
|
try:
|
||||||
failure_count += 1
|
for case_data in cases:
|
||||||
|
case_id = get_uuid()
|
||||||
|
case_info = {
|
||||||
|
"id": case_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"question": case_data.get("question", ""),
|
||||||
|
"reference_answer": case_data.get("reference_answer"),
|
||||||
|
"relevant_doc_ids": case_data.get("relevant_doc_ids"),
|
||||||
|
"relevant_chunk_ids": case_data.get("relevant_chunk_ids"),
|
||||||
|
"metadata": case_data.get("metadata"),
|
||||||
|
"create_time": cur_timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
case_instances.append(EvaluationCase(**case_info))
|
||||||
|
EvaluationCase.bulk_create(case_instances, batch_size=300)
|
||||||
|
success_count = len(case_instances)
|
||||||
|
failure_count = 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error bulk importing test cases: {str(e)}")
|
||||||
|
failure_count = len(cases)
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
return success_count, failure_count
|
return success_count, failure_count
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class FileService(CommonService):
     # Returns:
     #     List of dictionaries containing dataset IDs and names
     kbs = (
-        cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
+        cls.model.select(*[Knowledgebase.id, Knowledgebase.name, File2Document.document_id])
         .join(File2Document, on=(File2Document.file_id == file_id))
         .join(Document, on=(File2Document.document_id == Document.id))
         .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
@@ -110,7 +110,7 @@ class FileService(CommonService):
         return []
     kbs_info_list = []
     for kb in list(kbs.dicts()):
-        kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
+        kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"], "document_id": kb["document_id"]})
     return kbs_info_list

 @classmethod
@@ -439,6 +439,15 @@ class FileService(CommonService):

     err, files = [], []
     for file in file_objs:
+        doc_id = file.id if hasattr(file, "id") else get_uuid()
+        e, doc = DocumentService.get_by_id(doc_id)
+        if e:
+            blob = file.read()
+            settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id)
+            doc.size = len(blob)
+            doc = doc.to_dict()
+            DocumentService.update_by_id(doc["id"], doc)
+            continue
         try:
             DocumentService.check_doc_health(kb.tenant_id, file.filename)
             filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
@@ -455,7 +464,6 @@ class FileService(CommonService):
             blob = read_potential_broken_pdf(blob)
         settings.STORAGE_IMPL.put(kb.id, location, blob)

-        doc_id = get_uuid()

         img = thumbnail_img(filename, blob)
         thumbnail_location = ""

@@ -397,7 +397,7 @@ class KnowledgebaseService(CommonService):
     if dataset_name == "":
         return False, get_data_error_result(message="Dataset name can't be empty.")
     if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
-        return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
+        return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}")

     # Deduplicate name within tenant
     dataset_name = duplicate_name(
@@ -425,6 +425,7 @@ class KnowledgebaseService(CommonService):

     # Update parser_config (always override with validated default/merged config)
     payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
+    payload["parser_config"]["llm_id"] = _t.llm_id

     return True, payload

@@ -64,10 +64,13 @@ class TenantLangfuseService(CommonService):

     @classmethod
     def save(cls, **kwargs):
-        kwargs["create_time"] = current_timestamp()
-        kwargs["create_date"] = datetime_format(datetime.now())
-        kwargs["update_time"] = current_timestamp()
-        kwargs["update_date"] = datetime_format(datetime.now())
+        current_ts = current_timestamp()
+        current_date = datetime_format(datetime.now())
+
+        kwargs["create_time"] = current_ts
+        kwargs["create_date"] = current_date
+        kwargs["update_time"] = current_ts
+        kwargs["update_date"] = current_date
         obj = cls.model.create(**kwargs)
         return obj

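The save() refactors in this changeset all follow the same pattern: compute the timestamp pair once and reuse it for the create_* and update_* columns so the two can never drift within a single insert. A minimal sketch of the pattern in isolation; stamp_new_record is a hypothetical helper, only the common.time_utils calls come from the diff above:

    from datetime import datetime
    from common.time_utils import current_timestamp, datetime_format

    def stamp_new_record(kwargs: dict) -> dict:
        # One timestamp pair, reused for all four audit columns.
        current_ts = current_timestamp()
        current_date = datetime_format(datetime.now())
        kwargs.update(create_time=current_ts, create_date=current_date,
                      update_time=current_ts, update_date=current_date)
        return kwargs
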
@@ -441,3 +441,46 @@ class LLMBundle(LLM4Tenant):
             generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
             generation.end()
         return
+
+    async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+        total_tokens = 0
+        ans = ""
+        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
+            stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
+        elif hasattr(self.mdl, "async_chat_streamly"):
+            stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+        else:
+            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+
+        generation = None
+        if self.langfuse:
+            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
+
+        if stream_fn:
+            chat_partial = partial(stream_fn, system, history, gen_conf)
+            use_kwargs = self._clean_param(chat_partial, **kwargs)
+            try:
+                async for txt in chat_partial(**use_kwargs):
+                    if isinstance(txt, int):
+                        total_tokens = txt
+                        break
+
+                    if txt.endswith("</think>"):
+                        ans = ans[: -len("</think>")]
+
+                    if not self.verbose_tool_use:
+                        txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+
+                    ans += txt
+                    yield txt
+            except Exception as e:
+                if generation:
+                    generation.update(output={"error": str(e)})
+                    generation.end()
+                raise
+
+        if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+            logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+        if generation:
+            generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
+            generation.end()
+        return

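A minimal sketch of consuming the new delta-style generator. The bundle construction is not shown and the prompts are placeholders; the only assumption taken from the diff is that the generator yields incremental text chunks rather than the accumulated answer:

    import asyncio

    async def collect_answer(bundle):
        # bundle is assumed to be an LLMBundle configured for chat.
        answer = ""
        async for delta in bundle.async_chat_streamly_delta("You are a helpful assistant.", [{"role": "user", "content": "Hello"}], {"temperature": 0.1}):
            answer += delta  # each yielded value is a new chunk of text
        return answer

    # asyncio.run(collect_answer(bundle)) would return the assembled reply once the stream ends.
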
@@ -15,7 +15,6 @@
 #
 from typing import List

-from api.apps import current_user
 from api.db.db_models import DB, Memory, User
 from api.db.services import duplicate_name
 from api.db.services.common_service import CommonService
@@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
 from api.constants import MEMORY_NAME_LIMIT
 from common.misc_utils import get_uuid
 from common.time_utils import get_format_time, current_timestamp
+from memory.utils.prompt_util import PromptAssembler


 class MemoryService(CommonService):
@@ -34,6 +34,17 @@ class MemoryService(CommonService):
     def get_by_memory_id(cls, memory_id: str):
         return cls.model.select().where(cls.model.id == memory_id).first()

+    @classmethod
+    @DB.connection_context()
+    def get_by_tenant_id(cls, tenant_id: str):
+        return cls.model.select().where(cls.model.tenant_id == tenant_id)
+
+    @classmethod
+    @DB.connection_context()
+    def get_all_memory(cls):
+        memory_list = cls.model.select()
+        return list(memory_list)
+
     @classmethod
     @DB.connection_context()
     def get_with_owner_name_by_id(cls, memory_id: str):
@@ -53,7 +64,9 @@ class MemoryService(CommonService):
             cls.model.forgetting_policy,
             cls.model.temperature,
             cls.model.system_prompt,
-            cls.model.user_prompt
+            cls.model.user_prompt,
+            cls.model.create_date,
+            cls.model.create_time
         ]
         memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
             cls.model.id == memory_id
@@ -72,7 +85,9 @@ class MemoryService(CommonService):
             cls.model.memory_type,
             cls.model.storage_type,
             cls.model.permissions,
-            cls.model.description
+            cls.model.description,
+            cls.model.create_time,
+            cls.model.create_date
         ]
         memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
         if filter_dict.get("tenant_id"):
@@ -102,6 +117,8 @@ class MemoryService(CommonService):
         if len(memory_name) > MEMORY_NAME_LIMIT:
             return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."

+        timestamp = current_timestamp()
+        format_time = get_format_time()
         # build create dict
         memory_info = {
             "id": get_uuid(),
@@ -110,10 +127,11 @@ class MemoryService(CommonService):
             "tenant_id": tenant_id,
             "embd_id": embd_id,
             "llm_id": llm_id,
-            "create_time": current_timestamp(),
-            "create_date": get_format_time(),
-            "update_time": current_timestamp(),
-            "update_date": get_format_time(),
+            "system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
+            "create_time": timestamp,
+            "create_date": format_time,
+            "update_time": timestamp,
+            "update_date": format_time,
         }
         obj = cls.model(**memory_info).save(force_insert=True)

@@ -126,16 +144,18 @@ class MemoryService(CommonService):

     @classmethod
     @DB.connection_context()
-    def update_memory(cls, memory_id: str, update_dict: dict):
+    def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
         if not update_dict:
             return 0
         if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
             update_dict["temperature"] = float(update_dict["temperature"])
+        if "memory_type" in update_dict and isinstance(update_dict["memory_type"], list):
+            update_dict["memory_type"] = calculate_memory_type(update_dict["memory_type"])
         if "name" in update_dict:
             update_dict["name"] = duplicate_name(
                 cls.query,
                 name=update_dict["name"],
-                tenant_id=current_user.id
+                tenant_id=tenant_id
             )
         update_dict.update({
             "update_time": current_timestamp(),
@@ -147,4 +167,4 @@ class MemoryService(CommonService):
     @classmethod
     @DB.connection_context()
     def delete_memory(cls, memory_id: str):
-        return cls.model.delete().where(cls.model.id == memory_id).execute()
+        return cls.delete_by_id(memory_id)

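With the new signature the caller passes the tenant explicitly instead of relying on the request-bound current_user. A hedged usage sketch; the module path and the ids are assumptions, not taken from the diff:

    from api.db.services.memory_service import MemoryService  # module path assumed

    # tenant and memory ids are placeholders; a string temperature is coerced to float,
    # and a list-valued memory_type is folded into the bitmask via calculate_memory_type().
    MemoryService.update_memory("tenant-0001", "memory-0002", {
        "name": "project notes",
        "temperature": "0.3",
    })
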
@@ -169,11 +169,12 @@ class PipelineOperationLogService(CommonService):
             operation_status=operation_status,
             avatar=avatar,
         )
-        log["create_time"] = current_timestamp()
-        log["create_date"] = datetime_format(datetime.now())
-        log["update_time"] = current_timestamp()
-        log["update_date"] = datetime_format(datetime.now())
+        timestamp = current_timestamp()
+        datetime_now = datetime_format(datetime.now())
+        log["create_time"] = timestamp
+        log["create_date"] = datetime_now
+        log["update_time"] = timestamp
+        log["update_date"] = datetime_now
         with DB.atomic():
             obj = cls.save(**log)

@@ -28,10 +28,13 @@ class SearchService(CommonService):

     @classmethod
     def save(cls, **kwargs):
-        kwargs["create_time"] = current_timestamp()
-        kwargs["create_date"] = datetime_format(datetime.now())
-        kwargs["update_time"] = current_timestamp()
-        kwargs["update_date"] = datetime_format(datetime.now())
+        current_ts = current_timestamp()
+        current_date = datetime_format(datetime.now())
+
+        kwargs["create_time"] = current_ts
+        kwargs["create_date"] = current_date
+        kwargs["update_time"] = current_ts
+        kwargs["update_date"] = current_date
         obj = cls.model.create(**kwargs)
         return obj

api/db/services/system_settings_service.py (new file, 44 lines)
@@ -0,0 +1,44 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime

from common.time_utils import current_timestamp, datetime_format
from api.db.db_models import DB
from api.db.db_models import SystemSettings
from api.db.services.common_service import CommonService


class SystemSettingsService(CommonService):
    model = SystemSettings

    @classmethod
    @DB.connection_context()
    def get_by_name(cls, name):
        objs = cls.model.select().where(cls.model.name.startswith(name))
        return objs

    @classmethod
    @DB.connection_context()
    def update_by_name(cls, name, obj):
        obj["update_time"] = current_timestamp()
        obj["update_date"] = datetime_format(datetime.now())
        cls.model.update(obj).where(cls.model.name.startswith(name)).execute()
        return SystemSettings(**obj)

    @classmethod
    @DB.connection_context()
    def get_record_count(cls):
        count = cls.model.select().count()
        return count

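A small usage sketch for the new service. The setting name and the field written in update_by_name are placeholders, since the SystemSettings schema is not part of this listing:

    from api.db.services.system_settings_service import SystemSettingsService

    rows = SystemSettingsService.get_by_name("smtp")         # prefix match via name.startswith()
    total = SystemSettingsService.get_record_count()
    # update_by_name stamps update_time/update_date and returns a SystemSettings built from the dict.
    updated = SystemSettingsService.update_by_name("smtp", {"value": "mail.example.com"})
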
@@ -121,13 +121,6 @@ class TaskService(CommonService):
         .where(cls.model.id == task_id)
     )
     docs = list(docs.dicts())
-    # Assuming docs = list(docs.dicts())
-    if docs:
-        kb_config = docs[0]['kb_parser_config']  # Dict from Knowledgebase.parser_config
-        mineru_method = kb_config.get('mineru_parse_method', 'auto')
-        mineru_formula = kb_config.get('mineru_formula_enable', True)
-        mineru_table = kb_config.get('mineru_table_enable', True)
-        print(mineru_method, mineru_formula, mineru_table)
     if not docs:
         return None

@@ -179,6 +172,40 @@ class TaskService(CommonService):
             return None
         return tasks

+    @classmethod
+    @DB.connection_context()
+    def get_tasks_progress_by_doc_ids(cls, doc_ids: list[str]):
+        """Retrieve all tasks associated with specific documents.
+
+        This method fetches all processing tasks for given document ids, ordered by
+        creation time. It includes task progress and chunk information.
+
+        Args:
+            doc_ids (str): The unique identifier of the document.
+
+        Returns:
+            list[dict]: List of task dictionaries containing task details.
+            Returns None if no tasks are found.
+        """
+        fields = [
+            cls.model.id,
+            cls.model.doc_id,
+            cls.model.from_page,
+            cls.model.progress,
+            cls.model.progress_msg,
+            cls.model.digest,
+            cls.model.chunk_ids,
+            cls.model.create_time
+        ]
+        tasks = (
+            cls.model.select(*fields).order_by(cls.model.create_time.desc())
+            .where(cls.model.doc_id.in_(doc_ids))
+        )
+        tasks = list(tasks.dicts())
+        if not tasks:
+            return None
+        return tasks
+
     @classmethod
     @DB.connection_context()
     def update_chunk_ids(cls, id: str, chunk_ids: str):
@@ -495,6 +522,7 @@ def cancel_all_task_of(doc_id):
 def has_canceled(task_id):
     try:
         if REDIS_CONN.get(f"{task_id}-cancel"):
+            logging.info(f"Task: {task_id} has been canceled")
             return True
     except Exception as e:
         logging.exception(e)

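A hedged sketch of polling parsing progress for a batch of documents with the new helper; the document ids are placeholders and the None-when-empty behaviour comes straight from the method above:

    from api.db.services.task_service import TaskService

    tasks = TaskService.get_tasks_progress_by_doc_ids(["doc-1", "doc-2"]) or []
    for t in tasks:
        # Each row carries per-task progress, the latest progress message and chunk ids.
        print(t["doc_id"], t["progress"], t["progress_msg"])
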
@@ -19,7 +19,7 @@ import logging
 from peewee import IntegrityError
 from langfuse import Langfuse
 from common import settings
-from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType
+from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType
 from api.db.db_models import DB, LLMFactories, TenantLLM
 from api.db.services.common_service import CommonService
 from api.db.services.langfuse_service import TenantLangfuseService
@@ -60,10 +60,8 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
-                  cls.model.used_tokens, cls.model.status]
-        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
-            cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
+        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens, cls.model.status]
+        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()

         return list(objs)

@@ -90,6 +88,7 @@ class TenantLLMService(CommonService):
     @DB.connection_context()
     def get_model_config(cls, tenant_id, llm_type, llm_name=None):
         from api.db.services.llm_service import LLMService
+
         e, tenant = TenantService.get_by_id(tenant_id)
         if not e:
             raise LookupError("Tenant not found")
@@ -97,7 +96,7 @@ class TenantLLMService(CommonService):
         if llm_type == LLMType.EMBEDDING.value:
             mdlnm = tenant.embd_id if not llm_name else llm_name
         elif llm_type == LLMType.SPEECH2TEXT.value:
-            mdlnm = tenant.asr_id
+            mdlnm = tenant.asr_id if not llm_name else llm_name
         elif llm_type == LLMType.IMAGE2TEXT.value:
             mdlnm = tenant.img2txt_id if not llm_name else llm_name
         elif llm_type == LLMType.CHAT.value:
@@ -119,9 +118,9 @@ class TenantLLMService(CommonService):
         model_config = cls.get_api_key(tenant_id, mdlnm)
         if model_config:
             model_config = model_config.to_dict()
-        elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
+        elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
             embedding_cfg = settings.EMBEDDING_CFG
-            model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
+            model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
         else:
             raise LookupError(f"Model({mdlnm}@{fid}) not authorized")

@@ -140,33 +139,27 @@ class TenantLLMService(CommonService):
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return None
-            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
-                                                               base_url=model_config["api_base"])
+            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

         elif llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
                 return None
-            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
-                                                            base_url=model_config["api_base"])
+            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])

         elif llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return None
-            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
-                                                        base_url=model_config["api_base"], **kwargs)
+            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)

         elif llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return None
-            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
-                                                          base_url=model_config["api_base"], **kwargs)
+            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)

         elif llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
                 return None
-            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
-                                                             model_name=model_config["llm_name"], lang=lang,
-                                                             base_url=model_config["api_base"])
+            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
         elif llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
                 return None
@@ -216,14 +209,11 @@ class TenantLLMService(CommonService):
         try:
             num = (
                 cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
-                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name,
-                       cls.model.llm_factory == llm_factory if llm_factory else True)
+                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
                 .execute()
             )
         except Exception:
-            logging.exception(
-                "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
-                tenant_id, llm_name)
+            logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
             return 0

         return num
@@ -231,9 +221,7 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_openai_models(cls):
-        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"),
-                                        ~(cls.model.llm_name == "text-embedding-3-small"),
-                                        ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
+        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)

     @classmethod
@@ -298,6 +286,68 @@ class TenantLLMService(CommonService):
             idx += 1
             continue

+    @classmethod
+    def _collect_paddleocr_env_config(cls) -> dict | None:
+        cfg = PADDLEOCR_DEFAULT_CONFIG
+        found = False
+        for key in PADDLEOCR_ENV_KEYS:
+            val = os.environ.get(key)
+            if val:
+                found = True
+                cfg[key] = val
+        return cfg if found else None
+
+    @classmethod
+    @DB.connection_context()
+    def ensure_paddleocr_from_env(cls, tenant_id: str) -> str | None:
+        """
+        Ensure a PaddleOCR model exists for the tenant if env variables are present.
+        Return the existing or newly created llm_name, or None if env not set.
+        """
+        cfg = cls._collect_paddleocr_env_config()
+        if not cfg:
+            return None
+
+        saved_paddleocr_models = cls.query(tenant_id=tenant_id, llm_factory="PaddleOCR", model_type=LLMType.OCR.value)
+
+        def _parse_api_key(raw: str) -> dict:
+            try:
+                return json.loads(raw or "{}")
+            except Exception:
+                return {}
+
+        for item in saved_paddleocr_models:
+            api_cfg = _parse_api_key(item.api_key)
+            normalized = {k: api_cfg.get(k, PADDLEOCR_DEFAULT_CONFIG.get(k)) for k in PADDLEOCR_ENV_KEYS}
+            if normalized == cfg:
+                return item.llm_name
+
+        used_names = {item.llm_name for item in saved_paddleocr_models}
+        idx = 1
+        base_name = "paddleocr-from-env"
+        while True:
+            candidate = f"{base_name}-{idx}"
+            if candidate in used_names:
+                idx += 1
+                continue
+
+            try:
+                cls.save(
+                    tenant_id=tenant_id,
+                    llm_factory="PaddleOCR",
+                    llm_name=candidate,
+                    model_type=LLMType.OCR.value,
+                    api_key=json.dumps(cfg),
+                    api_base="",
+                    max_tokens=0,
+                )
+                return candidate
+            except IntegrityError:
+                logging.warning("PaddleOCR env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
+                used_names.add(candidate)
+                idx += 1
+                continue
+
     @classmethod
     @DB.connection_context()
     def delete_by_tenant_id(cls, tenant_id):
@@ -306,6 +356,7 @@ class TenantLLMService(CommonService):
     @staticmethod
     def llm_id2llm_type(llm_id: str) -> str | None:
         from api.db.services.llm_service import LLMService
+
         llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
         llm_factories = settings.FACTORY_LLM_INFOS
         for llm_factory in llm_factories:
@@ -340,9 +391,12 @@ class LLM4Tenant:
         langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
         self.langfuse = None
         if langfuse_keys:
-            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key,
-                                host=langfuse_keys.host)
-            if langfuse.auth_check():
-                self.langfuse = langfuse
-                trace_id = self.langfuse.create_trace_id()
-                self.trace_context = {"trace_id": trace_id}
+            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
+            try:
+                if langfuse.auth_check():
+                    self.langfuse = langfuse
+                    trace_id = self.langfuse.create_trace_id()
+                    self.trace_context = {"trace_id": trace_id}
+            except Exception:
+                # Skip langfuse tracing if connection fails
+                pass

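A sketch of how the env-driven PaddleOCR registration might be exercised. The endpoint and the tenant id are placeholders; the three PADDLEOCR_* variable names come from the constants change further down in this changeset:

    import os
    from api.db.services.tenant_llm_service import TenantLLMService  # module path assumed

    os.environ["PADDLEOCR_API_URL"] = "http://paddleocr.local:8080"  # placeholder endpoint
    os.environ["PADDLEOCR_ALGORITHM"] = "PaddleOCR-VL"

    # Creates (or reuses) a "paddleocr-from-env-N" OCR entry for the tenant and returns its llm_name,
    # or None when no PADDLEOCR_* variable is set.
    llm_name = TenantLLMService.ensure_paddleocr_from_env("tenant-0001")
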
@@ -116,10 +116,13 @@ class UserService(CommonService):
         kwargs["password"] = generate_password_hash(
             str(kwargs["password"]))

-        kwargs["create_time"] = current_timestamp()
-        kwargs["create_date"] = datetime_format(datetime.now())
-        kwargs["update_time"] = current_timestamp()
-        kwargs["update_date"] = datetime_format(datetime.now())
+        current_ts = current_timestamp()
+        current_date = datetime_format(datetime.now())
+
+        kwargs["create_time"] = current_ts
+        kwargs["create_date"] = current_date
+        kwargs["update_time"] = current_ts
+        kwargs["update_date"] = current_date
         obj = cls.model(**kwargs).save(force_insert=True)
         return obj

@@ -161,7 +164,7 @@ class UserService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_all_users(cls):
-        users = cls.model.select()
+        users = cls.model.select().order_by(cls.model.email)
         return list(users)

@@ -18,8 +18,8 @@
 # from beartype.claw import beartype_all  # <-- you didn't sign up for this
 # beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code

-from common.log_utils import init_root_logger
-from plugin import GlobalPluginManager
+import time
+start_ts = time.time()

 import logging
 import os
@@ -40,6 +40,8 @@ from api.db.init_data import init_web_data, init_superuser
 from common.versions import get_ragflow_version
 from common.config_utils import show_configs
 from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
+from common.log_utils import init_root_logger
+from plugin import GlobalPluginManager
 from rag.utils.redis_conn import RedisDistributedLock

 stop_event = threading.Event()
@@ -145,7 +147,7 @@ if __name__ == '__main__':

     # start http server
     try:
-        logging.info("RAGFlow HTTP server start...")
+        logging.info(f"RAGFlow server is ready after {time.time() - start_ts}s initialization.")
         app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
     except Exception:
         traceback.print_exc()

@@ -29,8 +29,15 @@ import requests
 from quart import (
     Response,
     jsonify,
-    request
+    request,
+    has_app_context,
 )
+from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
+
+try:
+    from quart.exceptions import BadRequest as QuartBadRequest
+except ImportError:  # pragma: no cover - optional dependency
+    QuartBadRequest = None

 from peewee import OperationalError

@@ -42,41 +49,45 @@ from api.db.services.tenant_llm_service import LLMFactoriesService
 from common.connection_utils import timeout
 from common.constants import RetCode
 from common import settings
+from common.misc_utils import thread_pool_exec

 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)


+def _safe_jsonify(payload: dict):
+    if has_app_context():
+        return jsonify(payload)
+    return payload
+
+
 async def _coerce_request_data() -> dict:
     """Fetch JSON body with sane defaults; fallback to form data."""
+    if hasattr(request, "_cached_payload"):
+        return request._cached_payload
     payload: Any = None
-    last_error: Exception | None = None
-    try:
-        payload = await request.get_json(force=True, silent=True)
-    except Exception as e:
-        last_error = e
-        payload = None
-    if payload is None:
-        try:
-            form = await request.form
-            payload = form.to_dict()
-        except Exception as e:
-            last_error = e
-            payload = None
-    if payload is None:
-        if last_error is not None:
-            raise last_error
-        raise ValueError("No JSON body or form data found in request.")
-    if isinstance(payload, dict):
-        return payload or {}
-    if isinstance(payload, str):
-        raise AttributeError("'str' object has no attribute 'get'")
-    raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
+    body_bytes = await request.get_data()
+    has_body = bool(body_bytes)
+    content_type = (request.content_type or "").lower()
+    is_json = content_type.startswith("application/json")
+
+    if not has_body:
+        payload = {}
+    elif is_json:
+        payload = await request.get_json(force=False, silent=False)
+        if isinstance(payload, dict):
+            payload = payload or {}
+        elif isinstance(payload, str):
+            raise AttributeError("'str' object has no attribute 'get'")
+        else:
+            raise TypeError("JSON payload must be an object.")
+    else:
+        form = await request.form
+        payload = form.to_dict() if form else None
+        if payload is None:
+            raise TypeError("Request body is not a valid form payload.")
+
+    request._cached_payload = payload
+    return payload


 async def get_request_json():
     return await _coerce_request_data()
@@ -115,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
             continue
         else:
             response[key] = value
-    return jsonify(response)
+    return _safe_jsonify(response)


 def server_error_response(e):
@@ -124,16 +135,12 @@ def server_error_response(e):
     try:
         msg = repr(e).lower()
         if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
-            return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
+            resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized")
+            resp.status_code = RetCode.UNAUTHORIZED
+            return resp
     except Exception as ex:
         logging.warning(f"error checking authorization: {ex}")

-    if len(e.args) > 1:
-        try:
-            serialized_data = serialize_for_json(e.args[1])
-            return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
-        except Exception:
-            return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
     if repr(e).find("index_not_found_exception") >= 0:
         return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")

@@ -163,11 +170,22 @@ def validate_request(*args, **kwargs):
            if error_arguments:
                error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
            return error_string
+        return None

    def wrapper(func):
        @wraps(func)
        async def decorated_function(*_args, **_kwargs):
-            errs = process_args(await _coerce_request_data())
+            exception_types = (AttributeError, TypeError, WerkzeugBadRequest)
+            if QuartBadRequest is not None:
+                exception_types = exception_types + (QuartBadRequest,)
+            if args or kwargs:
+                try:
+                    input_arguments = await _coerce_request_data()
+                except exception_types:
+                    input_arguments = {}
+            else:
+                input_arguments = await _coerce_request_data()
+            errs = process_args(input_arguments)
            if errs:
                return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
            if inspect.iscoroutinefunction(func):
@@ -214,7 +232,7 @@ def active_required(func):

 def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
     response = {"code": code, "message": message, "data": data}
-    return jsonify(response)
+    return _safe_jsonify(response)


 def apikey_required(func):
@@ -235,16 +253,16 @@ def apikey_required(func):

 def build_error_result(code=RetCode.FORBIDDEN, message="success"):
     response = {"code": code, "message": message}
-    response = jsonify(response)
-    response.status_code = code
+    response = _safe_jsonify(response)
+    if hasattr(response, "status_code"):
+        response.status_code = code
     return response


 def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
     if data is None:
-        return jsonify({"code": code, "message": message})
-    else:
-        return jsonify({"code": code, "message": message, "data": data})
+        return _safe_jsonify({"code": code, "message": message})
+    return _safe_jsonify({"code": code, "message": message, "data": data})


 def token_required(func):
@@ -303,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
     else:
         response["message"] = message or "Error"

-    return jsonify(response)
+    return _safe_jsonify(response)


 def get_error_data_result(
@@ -317,7 +335,7 @@ def get_error_data_result(
             continue
         else:
             response[key] = value
-    return jsonify(response)
+    return _safe_jsonify(response)


 def get_error_argument_result(message="Invalid arguments"):
@@ -409,7 +427,7 @@ def get_parser_config(chunk_method, parser_config):
     if default_config is None:
         return deep_merge(base_defaults, parser_config)

-    # Ensure raptor and graphrag fields have default values if not provided
+    # Ensure raptor and graph_rag fields have default values if not provided
     merged_config = deep_merge(base_defaults, default_config)
     merged_config = deep_merge(merged_config, parser_config)

@@ -682,7 +700,7 @@ async def is_strong_enough(chat_model, embedding_model):
     nonlocal chat_model, embedding_model
     if embedding_model:
         await asyncio.wait_for(
-            asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
+            thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
             timeout=10
        )

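The _safe_jsonify helper keeps the response builders usable outside a request context, which the old jsonify-only path did not allow. A minimal sketch of the difference; the module path is assumed from the hunk context, not stated in the diff:

    from api.utils.api_utils import get_json_result  # module path assumed

    def report_from_background_task():
        # No Quart app context here, so _safe_jsonify falls back to returning the plain dict
        # instead of raising; inside a request handler the same call still returns a Response.
        payload = get_json_result(data={"ok": True})
        assert isinstance(payload, dict)
        return payload
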
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import xxhash
+

 def string_to_bytes(string):
     return string if isinstance(
@@ -22,3 +24,6 @@ def string_to_bytes(string):
 def bytes_to_string(byte):
     return byte.decode(encoding="utf-8")

+# 128 bit = 32 character
+def hash128(data: str) -> str:
+    return xxhash.xxh128(data).hexdigest()

@@ -24,7 +24,7 @@ from common.file_utils import get_project_base_directory

 def crypt(line):
     """
-    decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
+    decrypt(crypt(input_string)) == base64(input_string), which frontend and ragflow_cli use.
     """
     file_path = os.path.join(get_project_base_directory(), "conf", "public.pem")
     rsa_key = RSA.importKey(open(file_path).read(), "Welcome")

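The new helper is a thin wrapper over xxhash, and the "128 bit = 32 character" comment can be checked directly:

    import xxhash

    # xxh128 produces a 128-bit digest, rendered as 32 hexadecimal characters.
    digest = xxhash.xxh128("ragflow").hexdigest()
    assert len(digest) == 32
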
@@ -42,7 +42,7 @@ def filename_type(filename):
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value

-    if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
+    if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|mdx|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
         return FileType.DOC.value

     if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename):

@@ -82,6 +82,8 @@ async def validate_and_parse_json_request(request: Request, validator: type[Base
        2. Extra fields added via `extras` parameter are automatically removed
           from the final output after validation
    """
+    if request.mimetype != "application/json":
+        return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
    try:
        payload = await request.get_json() or {}
    except UnsupportedMediaType:

@@ -69,6 +69,7 @@ CONTENT_TYPE_MAP = {
     # Web
     "md": "text/markdown",
     "markdown": "text/markdown",
+    "mdx": "text/markdown",
     "htm": "text/html",
     "html": "text/html",
     "json": "application/json",
@@ -85,6 +86,9 @@ CONTENT_TYPE_MAP = {
     "ico": "image/x-icon",
     "avif": "image/avif",
     "heic": "image/heic",
+    # PPTX
+    "ppt": "application/vnd.ms-powerpoint",
+    "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
 }

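Together with the filename_type regex change above, this makes .mdx files behave like Markdown end to end. A small check, with both import locations assumed rather than taken from the diff:

    from api.utils.file_utils import filename_type  # module path assumed
    from api.constants import FileType               # FileType location assumed

    # After the regex change an MDX source is classified as a document,
    # and CONTENT_TYPE_MAP["mdx"] serves it back as text/markdown.
    assert filename_type("component-guide.mdx") == FileType.DOC.value
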
@@ -20,6 +20,7 @@ from strenum import StrEnum
 SERVICE_CONF = "service_conf.yaml"
 RAG_FLOW_SERVICE_NAME = "ragflow"

+

 class CustomEnum(Enum):
     @classmethod
     def valid(cls, value):
@@ -54,6 +55,7 @@ class RetCode(IntEnum, CustomEnum):
     SERVER_ERROR = 500
     FORBIDDEN = 403
     NOT_FOUND = 404
+    CONFLICT = 409


 class StatusEnum(Enum):
@@ -67,13 +69,13 @@ class ActiveEnum(Enum):


 class LLMType(StrEnum):
-    CHAT = 'chat'
-    EMBEDDING = 'embedding'
-    SPEECH2TEXT = 'speech2text'
-    IMAGE2TEXT = 'image2text'
-    RERANK = 'rerank'
-    TTS = 'tts'
-    OCR = 'ocr'
+    CHAT = "chat"
+    EMBEDDING = "embedding"
+    SPEECH2TEXT = "speech2text"
+    IMAGE2TEXT = "image2text"
+    RERANK = "rerank"
+    TTS = "tts"
+    OCR = "ocr"


 class TaskStatus(StrEnum):
@@ -85,8 +87,7 @@ class TaskStatus(StrEnum):
     SCHEDULE = "5"


-VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL,
-                     TaskStatus.SCHEDULE}
+VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE}


 class ParserType(StrEnum):
@@ -124,6 +125,17 @@ class FileSource(StrEnum):
     MOODLE = "moodle"
     DROPBOX = "dropbox"
     BOX = "box"
+    R2 = "r2"
+    OCI_STORAGE = "oci_storage"
+    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
+    AIRTABLE = "airtable"
+    ASANA = "asana"
+    GITHUB = "github"
+    GITLAB = "gitlab"
+    IMAP = "imap"
+    BITBUCKET = "bitbucket"
+    ZENDESK = "zendesk"


 class PipelineTaskType(StrEnum):
     PARSE = "Parse"
@@ -131,17 +143,20 @@ class PipelineTaskType(StrEnum):
     RAPTOR = "RAPTOR"
     GRAPH_RAG = "GraphRAG"
     MINDMAP = "Mindmap"
+    MEMORY = "Memory"


-VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
-                             PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
+VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}


 class MCPServerType(StrEnum):
     SSE = "sse"
     STREAMABLE_HTTP = "streamable-http"


 VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}


 class Storage(Enum):
     MINIO = 1
     AZURE_SPN = 2
@@ -153,10 +168,10 @@ class Storage(Enum):


 class MemoryType(Enum):
     RAW = 0b0001  # 1 << 0 = 1 (0b00000001)
     SEMANTIC = 0b0010  # 1 << 1 = 2 (0b00000010)
     EPISODIC = 0b0100  # 1 << 2 = 4 (0b00000100)
     PROCEDURAL = 0b1000  # 1 << 3 = 8 (0b00001000)


 class MemoryStorageType(StrEnum):
@@ -165,7 +180,7 @@ class MemoryStorageType(StrEnum):


 class ForgettingPolicy(StrEnum):
-    FIFO = "fifo"
+    FIFO = "FIFO"


 # environment
@@ -227,3 +242,10 @@ MINERU_DEFAULT_CONFIG = {
     "MINERU_SERVER_URL": "",
     "MINERU_DELETE_OUTPUT": 1,
 }
+
+PADDLEOCR_ENV_KEYS = ["PADDLEOCR_API_URL", "PADDLEOCR_ACCESS_TOKEN", "PADDLEOCR_ALGORITHM"]
+PADDLEOCR_DEFAULT_CONFIG = {
+    "PADDLEOCR_API_URL": "",
+    "PADDLEOCR_ACCESS_TOKEN": None,
+    "PADDLEOCR_ALGORITHM": "PaddleOCR-VL",
+}

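The MemoryType members are bit flags, so several memory kinds can be folded into a single integer column and tested with bitwise operations; a short illustration using only the values defined above:

    from common.constants import MemoryType

    combined = MemoryType.RAW.value | MemoryType.SEMANTIC.value  # 0b0011 == 3
    assert combined & MemoryType.SEMANTIC.value                  # flag is set
    assert not combined & MemoryType.EPISODIC.value              # flag is not set
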
@@ -34,8 +34,11 @@ from .google_drive.connector import GoogleDriveConnector
 from .jira.connector import JiraConnector
 from .sharepoint_connector import SharePointConnector
 from .teams_connector import TeamsConnector
-from .webdav_connector import WebDAVConnector
 from .moodle_connector import MoodleConnector
+from .airtable_connector import AirtableConnector
+from .asana_connector import AsanaConnector
+from .imap_connector import ImapConnector
+from .zendesk_connector import ZendeskConnector
 from .config import BlobType, DocumentSource
 from .models import Document, TextSection, ImageSection, BasicExpertInfo
 from .exceptions import (
@@ -58,7 +61,6 @@ __all__ = [
     "JiraConnector",
     "SharePointConnector",
     "TeamsConnector",
-    "WebDAVConnector",
     "MoodleConnector",
     "BlobType",
     "DocumentSource",
@@ -70,5 +72,9 @@ __all__ = [
     "ConnectorValidationError",
     "CredentialExpiredError",
     "InsufficientPermissionsError",
-    "UnexpectedValidationError"
+    "UnexpectedValidationError",
+    "AirtableConnector",
+    "AsanaConnector",
+    "ImapConnector",
+    "ZendeskConnector",
 ]


168  common/data_source/airtable_connector.py  Normal file
@@ -0,0 +1,168 @@
from datetime import datetime, timezone
import logging
from typing import Any, Generator

import requests

from pyairtable import Api as AirtableApi

from common.data_source.config import AIRTABLE_CONNECTOR_SIZE_THRESHOLD, INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
from common.data_source.utils import extract_size_bytes, get_file_ext


class AirtableClientNotSetUpError(PermissionError):
    def __init__(self) -> None:
        super().__init__(
            "Airtable client is not set up. Did you forget to call load_credentials()?"
        )


class AirtableConnector(LoadConnector, PollConnector):
    """
    Lightweight Airtable connector.

    This connector ingests Airtable attachments as raw blobs without
    parsing file content or generating text/image sections.
    """

    def __init__(
        self,
        base_id: str,
        table_name_or_id: str,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.base_id = base_id
        self.table_name_or_id = table_name_or_id
        self.batch_size = batch_size
        self._airtable_client: AirtableApi | None = None
        self.size_threshold = AIRTABLE_CONNECTOR_SIZE_THRESHOLD

    # -------------------------
    # Credentials
    # -------------------------
    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self._airtable_client = AirtableApi(credentials["airtable_access_token"])
        return None

    @property
    def airtable_client(self) -> AirtableApi:
        if not self._airtable_client:
            raise AirtableClientNotSetUpError()
        return self._airtable_client

    # -------------------------
    # Core logic
    # -------------------------
    def load_from_state(self) -> GenerateDocumentsOutput:
        """
        Fetch all Airtable records and ingest attachments as raw blobs.

        Each attachment is converted into a single Document(blob=...).
        """
        if not self._airtable_client:
            raise ConnectorMissingCredentialError("Airtable credentials not loaded")

        table = self.airtable_client.table(self.base_id, self.table_name_or_id)
        records = table.all()

        logging.info(
            f"Starting Airtable blob ingestion for table {self.table_name_or_id}, "
            f"{len(records)} records found."
        )

        batch: list[Document] = []

        for record in records:
            record_id = record.get("id")
            fields = record.get("fields", {})
            created_time = record.get("createdTime")

            for field_value in fields.values():
                # We only care about attachment fields (lists of dicts with url/filename)
                if not isinstance(field_value, list):
                    continue

                for attachment in field_value:
                    url = attachment.get("url")
                    filename = attachment.get("filename")
                    attachment_id = attachment.get("id")

                    if not url or not filename or not attachment_id:
                        continue

                    try:
                        resp = requests.get(url, timeout=30)
                        resp.raise_for_status()
                        content = resp.content
                    except Exception:
                        logging.exception(
                            f"Failed to download attachment {filename} "
                            f"(record={record_id})"
                        )
                        continue
                    size_bytes = extract_size_bytes(attachment)
                    if (
                        self.size_threshold is not None
                        and isinstance(size_bytes, int)
                        and size_bytes > self.size_threshold
                    ):
                        logging.warning(
                            f"{filename} exceeds size threshold of {self.size_threshold}. Skipping."
                        )
                        continue
                    batch.append(
                        Document(
                            id=f"airtable:{record_id}:{attachment_id}",
                            blob=content,
                            source=DocumentSource.AIRTABLE,
                            semantic_identifier=filename,
                            extension=get_file_ext(filename),
                            size_bytes=size_bytes if size_bytes else 0,
                            doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
                        )
                    )

                    if len(batch) >= self.batch_size:
                        yield batch
                        batch = []

        if batch:
            yield batch

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
        """Poll source to get documents"""
        start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
        end_dt = datetime.fromtimestamp(end, tz=timezone.utc)

        for batch in self.load_from_state():
            filtered: list[Document] = []

            for doc in batch:
                if not doc.doc_updated_at:
                    continue

                doc_dt = doc.doc_updated_at.astimezone(timezone.utc)

                if start_dt <= doc_dt < end_dt:
                    filtered.append(doc)

            if filtered:
                yield filtered


if __name__ == "__main__":
    import os

    logging.basicConfig(level=logging.DEBUG)
    connector = AirtableConnector("xxx", "xxx")
    connector.load_credentials({"airtable_access_token": os.environ.get("AIRTABLE_ACCESS_TOKEN")})
    connector.validate_connector_settings()
    document_batches = connector.load_from_state()
    try:
        first_batch = next(document_batches)
        print(f"Loaded {len(first_batch)} documents in first batch.")
        for doc in first_batch:
            print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)")
    except StopIteration:
        print("No documents available in Airtable.")
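
A minimal usage sketch for the time-window path of this connector (the base/table IDs and token lookup below are placeholders, and it assumes the package above is importable):

import os
import time

from common.data_source.airtable_connector import AirtableConnector

connector = AirtableConnector("appXXXXXXXXXXXXXX", "tblXXXXXXXXXXXXXX")  # hypothetical base/table IDs
connector.load_credentials({"airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"]})

# poll_source filters on Document.doc_updated_at, which is derived from each record's createdTime
now = time.time()
for docs in connector.poll_source(start=now - 7 * 24 * 3600, end=now):
    for doc in docs:
        print(doc.semantic_identifier, doc.size_bytes)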

454  common/data_source/asana_connector.py  Normal file
@@ -0,0 +1,454 @@
from collections.abc import Iterator
import time
from datetime import datetime
import logging
from typing import Any, Dict

import asana
import requests

from common.data_source.config import CONTINUE_ON_CONNECTOR_FAILURE, INDEX_BATCH_SIZE, DocumentSource
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
from common.data_source.utils import extract_size_bytes, get_file_ext


# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
class AsanaTask:
    def __init__(
        self,
        id: str,
        title: str,
        text: str,
        link: str,
        last_modified: datetime,
        project_gid: str,
        project_name: str,
    ) -> None:
        self.id = id
        self.title = title
        self.text = text
        self.link = link
        self.last_modified = last_modified
        self.project_gid = project_gid
        self.project_name = project_name

    def __str__(self) -> str:
        return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"


class AsanaAPI:
    def __init__(
        self, api_token: str, workspace_gid: str, team_gid: str | None
    ) -> None:
        self._user = None
        self.workspace_gid = workspace_gid
        self.team_gid = team_gid

        self.configuration = asana.Configuration()
        self.api_client = asana.ApiClient(self.configuration)
        self.tasks_api = asana.TasksApi(self.api_client)
        self.attachments_api = asana.AttachmentsApi(self.api_client)
        self.stories_api = asana.StoriesApi(self.api_client)
        self.users_api = asana.UsersApi(self.api_client)
        self.project_api = asana.ProjectsApi(self.api_client)
        self.project_memberships_api = asana.ProjectMembershipsApi(self.api_client)
        self.workspaces_api = asana.WorkspacesApi(self.api_client)

        self.api_error_count = 0
        self.configuration.access_token = api_token
        self.task_count = 0

    def get_tasks(
        self, project_gids: list[str] | None, start_date: str
    ) -> Iterator[AsanaTask]:
        """Get all tasks from the projects with the given gids that were modified since the given date.
        If project_gids is None, get all tasks from all projects in the workspace."""
        logging.info("Starting to fetch Asana projects")
        projects = self.project_api.get_projects(
            opts={
                "workspace": self.workspace_gid,
                "opt_fields": "gid,name,archived,modified_at",
            }
        )
        start_seconds = int(time.mktime(datetime.now().timetuple()))
        projects_list = []
        project_count = 0
        for project_info in projects:
            project_gid = project_info["gid"]
            if project_gids is None or project_gid in project_gids:
                projects_list.append(project_gid)
            else:
                logging.debug(
                    f"Skipping project: {project_gid} - not in accepted project_gids"
                )
            project_count += 1
            if project_count % 100 == 0:
                logging.info(f"Processed {project_count} projects")
        logging.info(f"Found {len(projects_list)} projects to process")
        for project_gid in projects_list:
            for task in self._get_tasks_for_project(
                project_gid, start_date, start_seconds
            ):
                yield task
        logging.info(f"Completed fetching {self.task_count} tasks from Asana")
        if self.api_error_count > 0:
            logging.warning(
                f"Encountered {self.api_error_count} API errors during task fetching"
            )

    def _get_tasks_for_project(
        self, project_gid: str, start_date: str, start_seconds: int
    ) -> Iterator[AsanaTask]:
        project = self.project_api.get_project(project_gid, opts={})
        project_name = project.get("name", project_gid)
        team = project.get("team") or {}
        team_gid = team.get("gid")

        if project.get("archived"):
            logging.info(f"Skipping archived project: {project_name} ({project_gid})")
            return
        if not team_gid:
            logging.info(
                f"Skipping project without a team: {project_name} ({project_gid})"
            )
            return
        if project.get("privacy_setting") == "private":
            if self.team_gid and team_gid != self.team_gid:
                logging.info(
                    f"Skipping private project not in configured team: {project_name} ({project_gid})"
                )
                return
            logging.info(
                f"Processing private project in configured team: {project_name} ({project_gid})"
            )

        simple_start_date = start_date.split(".")[0].split("+")[0]
        logging.info(
            f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})"
        )

        opts = {
            "opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
            "created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
            "modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
            "workspace,permalink_url",
            "modified_since": start_date,
        }
        tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
        for data in tasks_from_api:
            self.task_count += 1
            if self.task_count % 10 == 0:
                end_seconds = time.mktime(datetime.now().timetuple())
                runtime_seconds = end_seconds - start_seconds
                if runtime_seconds > 0:
                    logging.info(
                        f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
                        f"({self.task_count / runtime_seconds:.2f} tasks/second)"
                    )

            logging.debug(f"Processing Asana task: {data['name']}")

            text = self._construct_task_text(data)

            try:
                text += self._fetch_and_add_comments(data["gid"])

                last_modified_date = self.format_date(data["modified_at"])
                text += f"Last modified: {last_modified_date}\n"

                task = AsanaTask(
                    id=data["gid"],
                    title=data["name"],
                    text=text,
                    link=data["permalink_url"],
                    last_modified=datetime.fromisoformat(data["modified_at"]),
                    project_gid=project_gid,
                    project_name=project_name,
                )
                yield task
            except Exception:
                logging.error(
                    f"Error processing task {data['gid']} in project {project_gid}",
                    exc_info=True,
                )
                self.api_error_count += 1

    def _construct_task_text(self, data: Dict) -> str:
        text = f"{data['name']}\n\n"

        if data["notes"]:
            text += f"{data['notes']}\n\n"

        if data["created_by"] and data["created_by"]["gid"]:
            creator = self.get_user(data["created_by"]["gid"])["name"]
            created_date = self.format_date(data["created_at"])
            text += f"Created by: {creator} on {created_date}\n"

        if data["due_on"]:
            due_date = self.format_date(data["due_on"])
            text += f"Due date: {due_date}\n"

        if data["completed_at"]:
            completed_date = self.format_date(data["completed_at"])
            text += f"Completed on: {completed_date}\n"

        text += "\n"
        return text

    def _fetch_and_add_comments(self, task_gid: str) -> str:
        text = ""
        stories_opts: Dict[str, str] = {}
        story_start = time.time()
        stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)

        story_count = 0
        comment_count = 0

        for story in stories:
            story_count += 1
            if story["resource_subtype"] == "comment_added":
                comment = self.stories_api.get_story(
                    story["gid"], opts={"opt_fields": "text,created_by,created_at"}
                )
                commenter = self.get_user(comment["created_by"]["gid"])["name"]
                text += f"Comment by {commenter}: {comment['text']}\n\n"
                comment_count += 1

        story_duration = time.time() - story_start
        logging.debug(
            f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
        )

        return text

    def get_attachments(self, task_gid: str) -> list[dict]:
        """
        Fetch full attachment info (including download_url) for a task.
        """
        attachments: list[dict] = []

        try:
            # Step 1: list attachment compact records
            for att in self.attachments_api.get_attachments_for_object(
                parent=task_gid,
                opts={}
            ):
                gid = att.get("gid")
                if not gid:
                    continue

                try:
                    # Step 2: expand to full attachment
                    full = self.attachments_api.get_attachment(
                        attachment_gid=gid,
                        opts={
                            "opt_fields": "name,download_url,size,created_at"
                        }
                    )

                    if full.get("download_url"):
                        attachments.append(full)

                except Exception:
                    logging.exception(
                        f"Failed to fetch attachment detail {gid} for task {task_gid}"
                    )
                    self.api_error_count += 1

        except Exception:
            logging.exception(f"Failed to list attachments for task {task_gid}")
            self.api_error_count += 1

        return attachments

    def get_accessible_emails(
        self,
        workspace_id: str,
        project_ids: list[str] | None,
        team_id: str | None,
    ):
        ws_users = self.users_api.get_users(
            opts={
                "workspace": workspace_id,
                "opt_fields": "gid,name,email"
            }
        )

        workspace_users = {
            u["gid"]: u.get("email")
            for u in ws_users
            if u.get("email")
        }

        if not project_ids:
            return set(workspace_users.values())

        project_emails = set()

        for pid in project_ids:
            project = self.project_api.get_project(
                pid,
                opts={"opt_fields": "team,privacy_setting"}
            )

            if project["privacy_setting"] == "private":
                if team_id and project.get("team", {}).get("gid") != team_id:
                    continue

            memberships = self.project_memberships_api.get_project_membership(
                pid,
                opts={"opt_fields": "user.gid,user.email"}
            )

            for m in memberships:
                email = m["user"].get("email")
                if email:
                    project_emails.add(email)

        return project_emails

    def get_user(self, user_gid: str) -> Dict:
        if self._user is not None:
            return self._user
        self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})

        if not self._user:
            logging.warning(f"Unable to fetch user information for user_gid: {user_gid}")
            return {"name": "Unknown"}
        return self._user

    def format_date(self, date_str: str) -> str:
        date = datetime.fromisoformat(date_str)
        return time.strftime("%Y-%m-%d", date.timetuple())

    def get_time(self) -> str:
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())


class AsanaConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        asana_workspace_id: str,
        asana_project_ids: str | None = None,
        asana_team_id: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
        continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
    ) -> None:
        self.workspace_id = asana_workspace_id
        self.project_ids_to_index: list[str] | None = (
            asana_project_ids.split(",") if asana_project_ids else None
        )
        self.asana_team_id = asana_team_id if asana_team_id else None
        self.batch_size = batch_size
        self.continue_on_failure = continue_on_failure
        self.size_threshold = None
        logging.info(
            f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.api_token = credentials["asana_api_token_secret"]
        self.asana_client = AsanaAPI(
            api_token=self.api_token,
            workspace_gid=self.workspace_id,
            team_gid=self.asana_team_id,
        )
        self.workspace_users_email = self.asana_client.get_accessible_emails(self.workspace_id, self.project_ids_to_index, self.asana_team_id)
        logging.info("Asana credentials loaded and API client initialized")
        return None

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
    ) -> GenerateDocumentsOutput:
        start_time = datetime.fromtimestamp(start).isoformat()
        logging.info(f"Starting Asana poll from {start_time}")
        docs_batch: list[Document] = []
        tasks = self.asana_client.get_tasks(self.project_ids_to_index, start_time)
        for task in tasks:
            docs = self._task_to_documents(task)
            docs_batch.extend(docs)

            if len(docs_batch) >= self.batch_size:
                logging.info(f"Yielding batch of {len(docs_batch)} documents")
                yield docs_batch
                docs_batch = []

        if docs_batch:
            logging.info(f"Yielding final batch of {len(docs_batch)} documents")
            yield docs_batch

        logging.info("Asana poll completed")

    def load_from_state(self) -> GenerateDocumentsOutput:
        logging.info("Starting full index of all Asana tasks")
        return self.poll_source(start=0, end=None)

    def _task_to_documents(self, task: AsanaTask) -> list[Document]:
        docs: list[Document] = []

        attachments = self.asana_client.get_attachments(task.id)

        for att in attachments:
            try:
                resp = requests.get(att["download_url"], timeout=30)
                resp.raise_for_status()
                file_blob = resp.content
                filename = att.get("name", "attachment")
                size_bytes = extract_size_bytes(att)
                if (
                    self.size_threshold is not None
                    and isinstance(size_bytes, int)
                    and size_bytes > self.size_threshold
                ):
                    logging.warning(
                        f"{filename} exceeds size threshold of {self.size_threshold}. Skipping."
                    )
                    continue
                docs.append(
                    Document(
                        id=f"asana:{task.id}:{att['gid']}",
                        blob=file_blob,
                        extension=get_file_ext(filename) or "",
                        size_bytes=size_bytes,
                        doc_updated_at=task.last_modified,
                        source=DocumentSource.ASANA,
                        semantic_identifier=filename,
                        primary_owners=list(self.workspace_users_email),
                    )
                )
            except Exception:
                logging.exception(
                    f"Failed to download attachment {att.get('gid')} for task {task.id}"
                )

        return docs


if __name__ == "__main__":
    import time
    import os

    logging.info("Starting Asana connector test")
    connector = AsanaConnector(
        os.environ["WORKSPACE_ID"],
        os.environ["PROJECT_IDS"],
        os.environ["TEAM_ID"],
    )
    connector.load_credentials(
        {
            "asana_api_token_secret": os.environ["API_TOKEN"],
        }
    )
    logging.info("Loading all documents from Asana")
    all_docs = connector.load_from_state()
    current = time.time()
    one_day_ago = current - 24 * 60 * 60  # 1 day
    logging.info("Polling for documents updated in the last 24 hours")
    latest_docs = connector.poll_source(one_day_ago, current)
    for docs in all_docs:
        for doc in docs:
            print(doc.id)
    logging.info("Asana connector test completed")
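
Note that the poll window is handed to the Asana API as an ISO-8601 `modified_since` string rather than as epoch seconds; a small sketch of that conversion (the timestamp value is illustrative):

import time
from datetime import datetime

one_day_ago = time.time() - 24 * 60 * 60
modified_since = datetime.fromtimestamp(one_day_ago).isoformat()
# e.g. "2024-05-01T09:30:00.123456" — this is the start_date string that
# AsanaConnector.poll_source() passes to AsanaAPI.get_tasks().
print(modified_since)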

0  common/data_source/bitbucket/__init__.py  Normal file

388  common/data_source/bitbucket/connector.py  Normal file
@@ -0,0 +1,388 @@
from __future__ import annotations

import copy
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING

from typing_extensions import override

from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.config import DocumentSource
from common.data_source.config import REQUEST_TIMEOUT_SECONDS
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    CredentialExpiredError,
    InsufficientPermissionsError,
    UnexpectedValidationError,
)
from common.data_source.interfaces import CheckpointedConnector
from common.data_source.interfaces import CheckpointOutput
from common.data_source.interfaces import IndexingHeartbeatInterface
from common.data_source.interfaces import SecondsSinceUnixEpoch
from common.data_source.interfaces import SlimConnectorWithPermSync
from common.data_source.models import ConnectorCheckpoint
from common.data_source.models import ConnectorFailure
from common.data_source.models import DocumentFailure
from common.data_source.models import SlimDocument
from common.data_source.bitbucket.utils import (
    build_auth_client,
    list_repositories,
    map_pr_to_document,
    paginate,
    PR_LIST_RESPONSE_FIELDS,
    SLIM_PR_LIST_RESPONSE_FIELDS,
)

if TYPE_CHECKING:
    import httpx


class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
    """Checkpoint state for resumable Bitbucket PR indexing.

    Fields:
        repos_queue: Materialized list of repository slugs to process.
        current_repo_index: Index of the repository currently being processed.
        next_url: Bitbucket "next" URL for continuing pagination within the current repo.
    """

    repos_queue: list[str] = []
    current_repo_index: int = 0
    next_url: str | None = None


class BitbucketConnector(
    CheckpointedConnector[BitbucketConnectorCheckpoint],
    SlimConnectorWithPermSync,
):
    """Connector for indexing Bitbucket Cloud pull requests.

    Args:
        workspace: Bitbucket workspace ID.
        repositories: Comma-separated list of repository slugs to index.
        projects: Comma-separated list of project keys to index all repositories within.
        batch_size: Max number of documents to yield per batch.
    """

    def __init__(
        self,
        workspace: str,
        repositories: str | None = None,
        projects: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.workspace = workspace
        self._repositories = (
            [s.strip() for s in repositories.split(",") if s.strip()]
            if repositories
            else None
        )
        self._projects: list[str] | None = (
            [s.strip() for s in projects.split(",") if s.strip()] if projects else None
        )
        self.batch_size = batch_size
        self.email: str | None = None
        self.api_token: str | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load API token-based credentials.

        Expects a dict with keys: `bitbucket_email`, `bitbucket_api_token`.
        """
        self.email = credentials.get("bitbucket_email")
        self.api_token = credentials.get("bitbucket_api_token")
        if not self.email or not self.api_token:
            raise ConnectorMissingCredentialError("Bitbucket")
        return None

    def _client(self) -> httpx.Client:
        """Build an authenticated HTTP client or raise if credentials missing."""
        if not self.email or not self.api_token:
            raise ConnectorMissingCredentialError("Bitbucket")
        return build_auth_client(self.email, self.api_token)

    def _iter_pull_requests_for_repo(
        self,
        client: httpx.Client,
        repo_slug: str,
        params: dict[str, Any] | None = None,
        start_url: str | None = None,
        on_page: Callable[[str | None], None] | None = None,
    ) -> Iterator[dict[str, Any]]:
        base = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}/{repo_slug}/pullrequests"
        yield from paginate(
            client,
            base,
            params,
            start_url=start_url,
            on_page=on_page,
        )

    def _build_params(
        self,
        fields: str = PR_LIST_RESPONSE_FIELDS,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> dict[str, Any]:
        """Build Bitbucket fetch params.

        Always include OPEN, MERGED, and DECLINED PRs. If both ``start`` and
        ``end`` are provided, apply a single updated_on time window.
        """

        def _iso(ts: SecondsSinceUnixEpoch) -> str:
            return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()

        def _tc_epoch(
            lower_epoch: SecondsSinceUnixEpoch | None,
            upper_epoch: SecondsSinceUnixEpoch | None,
        ) -> str | None:
            if lower_epoch is not None and upper_epoch is not None:
                lower_iso = _iso(lower_epoch)
                upper_iso = _iso(upper_epoch)
                return f'(updated_on > "{lower_iso}" AND updated_on <= "{upper_iso}")'
            return None

        params: dict[str, Any] = {"fields": fields, "pagelen": 50}
        time_clause = _tc_epoch(start, end)
        q = '(state = "OPEN" OR state = "MERGED" OR state = "DECLINED")'
        if time_clause:
            q = f"{q} AND {time_clause}"
        params["q"] = q
        return params

    def _iter_target_repositories(self, client: httpx.Client) -> Iterator[str]:
        """Yield repository slugs based on configuration.

        Priority:
        - repositories list
        - projects list (list repos by project key)
        - workspace (all repos)
        """
        if self._repositories:
            for slug in self._repositories:
                yield slug
            return
        if self._projects:
            for project_key in self._projects:
                for repo in list_repositories(client, self.workspace, project_key):
                    slug_val = repo.get("slug")
                    if isinstance(slug_val, str) and slug_val:
                        yield slug_val
            return
        for repo in list_repositories(client, self.workspace, None):
            slug_val = repo.get("slug")
            if isinstance(slug_val, str) and slug_val:
                yield slug_val

    @override
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: BitbucketConnectorCheckpoint,
    ) -> CheckpointOutput[BitbucketConnectorCheckpoint]:
        """Resumable PR ingestion across repos and pages within a time window.

        Yields Documents (or ConnectorFailure for per-PR mapping failures) and returns
        an updated checkpoint that records repo position and next page URL.
        """
        new_checkpoint = copy.deepcopy(checkpoint)

        with self._client() as client:
            # Materialize target repositories once
            if not new_checkpoint.repos_queue:
                # Preserve explicit order; otherwise ensure deterministic ordering
                repos_list = list(self._iter_target_repositories(client))
                new_checkpoint.repos_queue = sorted(set(repos_list))
                new_checkpoint.current_repo_index = 0
                new_checkpoint.next_url = None

            repos = new_checkpoint.repos_queue
            if not repos or new_checkpoint.current_repo_index >= len(repos):
                new_checkpoint.has_more = False
                return new_checkpoint

            repo_slug = repos[new_checkpoint.current_repo_index]

            first_page_params = self._build_params(
                fields=PR_LIST_RESPONSE_FIELDS,
                start=start,
                end=end,
            )

            def _on_page(next_url: str | None) -> None:
                new_checkpoint.next_url = next_url

            for pr in self._iter_pull_requests_for_repo(
                client,
                repo_slug,
                params=first_page_params,
                start_url=new_checkpoint.next_url,
                on_page=_on_page,
            ):
                try:
                    document = map_pr_to_document(pr, self.workspace, repo_slug)
                    yield document
                except Exception as e:
                    pr_id = pr.get("id")
                    pr_link = (
                        f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}"
                        if pr_id is not None
                        else None
                    )
                    yield ConnectorFailure(
                        failed_document=DocumentFailure(
                            document_id=(
                                f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:{pr_id}"
                                if pr_id is not None
                                else f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:unknown"
                            ),
                            document_link=pr_link,
                        ),
                        failure_message=f"Failed to process Bitbucket PR: {e}",
                        exception=e,
                    )

            # Advance to next repository (if any) and set has_more accordingly
            new_checkpoint.current_repo_index += 1
            new_checkpoint.next_url = None
            new_checkpoint.has_more = new_checkpoint.current_repo_index < len(repos)

        return new_checkpoint

    @override
    def build_dummy_checkpoint(self) -> BitbucketConnectorCheckpoint:
        """Create an initial checkpoint with work remaining."""
        return BitbucketConnectorCheckpoint(has_more=True)

    @override
    def validate_checkpoint_json(
        self, checkpoint_json: str
    ) -> BitbucketConnectorCheckpoint:
        """Validate and deserialize a checkpoint instance from JSON."""
        return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> Iterator[list[SlimDocument]]:
        """Return only document IDs for all existing pull requests."""
        batch: list[SlimDocument] = []
        params = self._build_params(
            fields=SLIM_PR_LIST_RESPONSE_FIELDS,
            start=start,
            end=end,
        )
        with self._client() as client:
            for slug in self._iter_target_repositories(client):
                for pr in self._iter_pull_requests_for_repo(
                    client, slug, params=params
                ):
                    pr_id = pr["id"]
                    doc_id = f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{slug}:pr:{pr_id}"
                    batch.append(SlimDocument(id=doc_id))
                    if len(batch) >= self.batch_size:
                        yield batch
                        batch = []
                    if callback:
                        if callback.should_stop():
                            # Note: this is not actually used for permission sync yet, just pruning
                            raise RuntimeError(
                                "bitbucket_pr_sync: Stop signal detected"
                            )
                        callback.progress("bitbucket_pr_sync", len(batch))
        if batch:
            yield batch

    def validate_connector_settings(self) -> None:
        """Validate Bitbucket credentials and workspace access by probing a lightweight endpoint.

        Raises:
            CredentialExpiredError: on HTTP 401
            InsufficientPermissionsError: on HTTP 403
            UnexpectedValidationError: on any other failure
        """
        try:
            with self._client() as client:
                url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}"
                resp = client.get(
                    url,
                    params={"pagelen": 1, "fields": "pagelen"},
                    timeout=REQUEST_TIMEOUT_SECONDS,
                )
                if resp.status_code == 401:
                    raise CredentialExpiredError(
                        "Invalid or expired Bitbucket credentials (HTTP 401)."
                    )
                if resp.status_code == 403:
                    raise InsufficientPermissionsError(
                        "Insufficient permissions to access Bitbucket workspace (HTTP 403)."
                    )
                if resp.status_code < 200 or resp.status_code >= 300:
                    raise UnexpectedValidationError(
                        f"Unexpected Bitbucket error (status={resp.status_code})."
                    )
        except Exception as e:
            # Network or other unexpected errors
            if isinstance(
                e,
                (
                    CredentialExpiredError,
                    InsufficientPermissionsError,
                    UnexpectedValidationError,
                    ConnectorMissingCredentialError,
                ),
            ):
                raise
            raise UnexpectedValidationError(
                f"Unexpected error while validating Bitbucket settings: {e}"
            )


if __name__ == "__main__":
    bitbucket = BitbucketConnector(
        workspace="<YOUR_WORKSPACE>"
    )

    bitbucket.load_credentials({
        "bitbucket_email": "<YOUR_EMAIL>",
        "bitbucket_api_token": "<YOUR_API_TOKEN>",
    })

    bitbucket.validate_connector_settings()
    print("Credentials validated successfully.")

    start_time = datetime.fromtimestamp(0, tz=timezone.utc)
    end_time = datetime.now(timezone.utc)

    for doc_batch in bitbucket.retrieve_all_slim_docs_perm_sync(
        start=start_time.timestamp(),
        end=end_time.timestamp(),
    ):
        for doc in doc_batch:
            print(doc)

    bitbucket_checkpoint = bitbucket.build_dummy_checkpoint()

    while bitbucket_checkpoint.has_more:
        gen = bitbucket.load_from_checkpoint(
            start=start_time.timestamp(),
            end=end_time.timestamp(),
            checkpoint=bitbucket_checkpoint,
        )

        while True:
            try:
                doc = next(gen)
                print(doc)
            except StopIteration as e:
                bitbucket_checkpoint = e.value
                break
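
A sketch of the Bitbucket query that `_build_params` produces for a bounded window (the workspace name is a placeholder; no network access is needed, assuming the module above is importable):

from datetime import datetime, timezone

from common.data_source.bitbucket.connector import BitbucketConnector

connector = BitbucketConnector(workspace="<YOUR_WORKSPACE>")
params = connector._build_params(
    start=datetime(2024, 1, 1, tzinfo=timezone.utc).timestamp(),
    end=datetime(2024, 2, 1, tzinfo=timezone.utc).timestamp(),
)
# params["q"] combines the state filter with the updated_on window, e.g.
# (state = "OPEN" OR state = "MERGED" OR state = "DECLINED") AND
# (updated_on > "2024-01-01T00:00:00+00:00" AND updated_on <= "2024-02-01T00:00:00+00:00")
print(params["q"], params["pagelen"])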

288  common/data_source/bitbucket/utils.py  Normal file
@@ -0,0 +1,288 @@
from __future__ import annotations

import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any

import httpx

from common.data_source.config import REQUEST_TIMEOUT_SECONDS, DocumentSource
from common.data_source.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from common.data_source.utils import sanitize_filename
from common.data_source.models import BasicExpertInfo, Document
from common.data_source.cross_connector_utils.retry_wrapper import retry_builder

# Fields requested from Bitbucket PR list endpoint to ensure rich PR data
PR_LIST_RESPONSE_FIELDS: str = ",".join(
    [
        "next",
        "page",
        "pagelen",
        "values.author",
        "values.close_source_branch",
        "values.closed_by",
        "values.comment_count",
        "values.created_on",
        "values.description",
        "values.destination",
        "values.draft",
        "values.id",
        "values.links",
        "values.merge_commit",
        "values.participants",
        "values.reason",
        "values.rendered",
        "values.reviewers",
        "values.source",
        "values.state",
        "values.summary",
        "values.task_count",
        "values.title",
        "values.type",
        "values.updated_on",
    ]
)

# Minimal fields for slim retrieval (IDs only)
SLIM_PR_LIST_RESPONSE_FIELDS: str = ",".join(
    [
        "next",
        "page",
        "pagelen",
        "values.id",
    ]
)


# Minimal fields for repository list calls
REPO_LIST_RESPONSE_FIELDS: str = ",".join(
    [
        "next",
        "page",
        "pagelen",
        "values.slug",
        "values.full_name",
        "values.project.key",
    ]
)


class BitbucketRetriableError(Exception):
    """Raised for retriable Bitbucket conditions (429, 5xx)."""


class BitbucketNonRetriableError(Exception):
    """Raised for non-retriable Bitbucket client errors (4xx except 429)."""


@retry_builder(
    tries=6,
    delay=1,
    backoff=2,
    max_delay=30,
    exceptions=(BitbucketRetriableError, httpx.RequestError),
)
@rate_limit_builder(max_calls=60, period=60)
def bitbucket_get(
    client: httpx.Client, url: str, params: dict[str, Any] | None = None
) -> httpx.Response:
    """Perform a GET against Bitbucket with retry and rate limiting.

    Retries on 429 and 5xx responses, and on transport errors. Honors
    `Retry-After` header for 429 when present by sleeping before retrying.
    """
    try:
        response = client.get(url, params=params, timeout=REQUEST_TIMEOUT_SECONDS)
    except httpx.RequestError:
        # Allow retry_builder to handle retries of transport errors
        raise

    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        status = e.response.status_code if e.response is not None else None
        if status == 429:
            retry_after = e.response.headers.get("Retry-After") if e.response else None
            if retry_after is not None:
                try:
                    time.sleep(int(retry_after))
                except (TypeError, ValueError):
                    pass
            raise BitbucketRetriableError("Bitbucket rate limit exceeded (429)") from e
        if status is not None and 500 <= status < 600:
            raise BitbucketRetriableError(f"Bitbucket server error: {status}") from e
        if status is not None and 400 <= status < 500:
            raise BitbucketNonRetriableError(f"Bitbucket client error: {status}") from e
        # Unknown status, propagate
        raise

    return response


def build_auth_client(email: str, api_token: str) -> httpx.Client:
    """Create an authenticated httpx client for Bitbucket Cloud API."""
    return httpx.Client(auth=(email, api_token), http2=True)


def paginate(
    client: httpx.Client,
    url: str,
    params: dict[str, Any] | None = None,
    start_url: str | None = None,
    on_page: Callable[[str | None], None] | None = None,
) -> Iterator[dict[str, Any]]:
    """Iterate over paginated Bitbucket API responses yielding individual values.

    Args:
        client: Authenticated HTTP client.
        url: Base collection URL (first page when start_url is None).
        params: Query params for the first page.
        start_url: If provided, start from this absolute URL (ignores params).
        on_page: Optional callback invoked after each page with the next page URL.
    """
    next_url = start_url or url
    # If resuming from a next URL, do not pass params again
    query = params.copy() if params else None
    query = None if start_url else query
    while next_url:
        resp = bitbucket_get(client, next_url, params=query)
        data = resp.json()
        values = data.get("values", [])
        for item in values:
            yield item
        next_url = data.get("next")
        if on_page is not None:
            on_page(next_url)
        # only include params on first call, next_url will contain all necessary params
        query = None


def list_repositories(
    client: httpx.Client, workspace: str, project_key: str | None = None
) -> Iterator[dict[str, Any]]:
    """List repositories in a workspace, optionally filtered by project key."""
    base_url = f"https://api.bitbucket.org/2.0/repositories/{workspace}"
    params: dict[str, Any] = {
        "fields": REPO_LIST_RESPONSE_FIELDS,
        "pagelen": 100,
        # Ensure deterministic ordering
        "sort": "full_name",
    }
    if project_key:
        params["q"] = f'project.key="{project_key}"'
    yield from paginate(client, base_url, params)


def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Document:
    """Map a Bitbucket pull request JSON to Onyx Document."""
    pr_id = pr["id"]
    title = pr.get("title") or f"PR {pr_id}"
    description = pr.get("description") or ""
    state = pr.get("state")
    draft = pr.get("draft", False)
    author = pr.get("author", {})
    reviewers = pr.get("reviewers", [])
    participants = pr.get("participants", [])

    link = pr.get("links", {}).get("html", {}).get("href") or (
        f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}"
    )

    created_on = pr.get("created_on")
    updated_on = pr.get("updated_on")
    updated_dt = (
        datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone(
            timezone.utc
        )
        if isinstance(updated_on, str)
        else None
    )

    source_branch = pr.get("source", {}).get("branch", {}).get("name", "")
    destination_branch = pr.get("destination", {}).get("branch", {}).get("name", "")

    approved_by = [
        _get_user_name(p.get("user", {})) for p in participants if p.get("approved")
    ]

    primary_owner = None
    if author:
        primary_owner = BasicExpertInfo(
            display_name=_get_user_name(author),
        )

    # secondary_owners = [
    #     BasicExpertInfo(display_name=_get_user_name(r)) for r in reviewers
    # ] or None

    reviewer_names = [_get_user_name(r) for r in reviewers]

    # Create a concise summary of key PR info
    created_date = created_on.split("T")[0] if created_on else "N/A"
    updated_date = updated_on.split("T")[0] if updated_on else "N/A"
    content_text = (
        "Pull Request Information:\n"
        f"- Pull Request ID: {pr_id}\n"
        f"- Title: {title}\n"
        f"- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n"
    )
    if state == "DECLINED":
        content_text += f"- Reason: {pr.get('reason', 'N/A')}\n"
    content_text += (
        f"- Author: {_get_user_name(author) if author else 'N/A'}\n"
        f"- Reviewers: {', '.join(reviewer_names) if reviewer_names else 'N/A'}\n"
        f"- Branch: {source_branch} -> {destination_branch}\n"
        f"- Created: {created_date}\n"
        f"- Updated: {updated_date}"
    )
    if description:
        content_text += f"\n\nDescription:\n{description}"

    metadata: dict[str, str | list[str]] = {
        "object_type": "PullRequest",
        "workspace": workspace,
        "repository": repo_slug,
        "pr_key": f"{workspace}/{repo_slug}#{pr_id}",
        "id": str(pr_id),
        "title": title,
        "state": state or "",
        "draft": str(bool(draft)),
        "link": link,
        "author": _get_user_name(author) if author else "",
        "reviewers": reviewer_names,
        "approved_by": approved_by,
        "comment_count": str(pr.get("comment_count", "")),
        "task_count": str(pr.get("task_count", "")),
        "created_on": created_on or "",
        "updated_on": updated_on or "",
        "source_branch": source_branch,
        "destination_branch": destination_branch,
        "closed_by": (
            _get_user_name(pr.get("closed_by", {})) if pr.get("closed_by") else ""
        ),
        "close_source_branch": str(bool(pr.get("close_source_branch", False))),
    }

    name = sanitize_filename(title, "md")

    return Document(
        id=f"{DocumentSource.BITBUCKET.value}:{workspace}:{repo_slug}:pr:{pr_id}",
        blob=content_text.encode("utf-8"),
        source=DocumentSource.BITBUCKET,
        extension=".md",
        semantic_identifier=f"#{pr_id}: {name}",
        size_bytes=len(content_text.encode("utf-8")),
        doc_updated_at=updated_dt,
        primary_owners=[primary_owner] if primary_owner else None,
        # secondary_owners=secondary_owners,
        metadata=metadata,
    )


def _get_user_name(user: dict[str, Any]) -> str:
    return user.get("display_name") or user.get("nickname") or "unknown"
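
A self-contained sketch of how `paginate` walks Bitbucket's `values`/`next` envelope, using a stub client so nothing is actually fetched (the only assumptions are the paged JSON shape and that this module imports cleanly in the repo):

from common.data_source.bitbucket.utils import paginate


class _StubResponse:
    def __init__(self, payload):
        self._payload = payload

    def raise_for_status(self):
        pass

    def json(self):
        return self._payload


class _StubClient:
    """Serves two fake pages keyed by URL, mimicking Bitbucket pagination."""

    pages = {
        "https://example/prs": {"values": [{"id": 1}, {"id": 2}], "next": "https://example/prs?page=2"},
        "https://example/prs?page=2": {"values": [{"id": 3}], "next": None},
    }

    def get(self, url, params=None, timeout=None):
        return _StubResponse(self.pages[url])


for pr in paginate(_StubClient(), "https://example/prs", on_page=print):
    print(pr["id"])  # 1, 2, 3 — on_page receives the next-page URL (then None) after each page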

@@ -64,16 +64,24 @@ class BlobStorageConnector(LoadConnector, PollConnector):
         elif self.bucket_type == BlobType.S3:
             authentication_method = credentials.get("authentication_method", "access_key")

             if authentication_method == "access_key":
                 if not all(
                     credentials.get(key)
                     for key in ["aws_access_key_id", "aws_secret_access_key"]
                 ):
                     raise ConnectorMissingCredentialError("Amazon S3")

             elif authentication_method == "iam_role":
                 if not credentials.get("aws_role_arn"):
                     raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
+
+            elif authentication_method == "assume_role":
+                pass
+
+            else:
+                raise ConnectorMissingCredentialError("Unsupported S3 authentication method")
+
         elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
             if not all(
                 credentials.get(key) for key in ["access_key_id", "secret_access_key"]
@@ -120,55 +128,72 @@ class BlobStorageConnector(LoadConnector, PollConnector):
         paginator = self.s3_client.get_paginator("list_objects_v2")
         pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)

-        batch: list[Document] = []
+        # Collect all objects first to count filename occurrences
+        all_objects = []
         for page in pages:
             if "Contents" not in page:
                 continue

             for obj in page["Contents"]:
                 if obj["Key"].endswith("/"):
                     continue

                 last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
+                if start < last_modified <= end:
+                    all_objects.append(obj)

-                if not (start < last_modified <= end):
-                    continue
+        # Count filename occurrences to determine which need full paths
+        filename_counts: dict[str, int] = {}
+        for obj in all_objects:
+            file_name = os.path.basename(obj["Key"])
+            filename_counts[file_name] = filename_counts.get(file_name, 0) + 1

-                file_name = os.path.basename(obj["Key"])
-                key = obj["Key"]
+        batch: list[Document] = []
+        for obj in all_objects:
+            last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
+            file_name = os.path.basename(obj["Key"])
+            key = obj["Key"]

-                size_bytes = extract_size_bytes(obj)
-                if (
-                    self.size_threshold is not None
-                    and isinstance(size_bytes, int)
-                    and size_bytes > self.size_threshold
-                ):
-                    logging.warning(
-                        f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
-                    )
-                    continue
-                try:
-                    blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
-                    if blob is None:
-                        continue
+            size_bytes = extract_size_bytes(obj)
+            if (
+                self.size_threshold is not None
+                and isinstance(size_bytes, int)
+                and size_bytes > self.size_threshold
+            ):
+                logging.warning(
+                    f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
+                )
+                continue

-                    batch.append(
-                        Document(
-                            id=f"{self.bucket_type}:{self.bucket_name}:{key}",
-                            blob=blob,
-                            source=DocumentSource(self.bucket_type.value),
-                            semantic_identifier=file_name,
-                            extension=get_file_ext(file_name),
-                            doc_updated_at=last_modified,
-                            size_bytes=size_bytes if size_bytes else 0
-                        )
-                    )
-                    if len(batch) == self.batch_size:
-                        yield batch
-                        batch = []
+            try:
+                blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
+                if blob is None:
+                    continue

-                except Exception:
-                    logging.exception(f"Error decoding object {key}")
+                # Use full path only if filename appears multiple times
+                if filename_counts.get(file_name, 0) > 1:
+                    relative_path = key
+                    if self.prefix and key.startswith(self.prefix):
+                        relative_path = key[len(self.prefix):]
+                    semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
+                else:
+                    semantic_id = file_name
+
+                batch.append(
+                    Document(
+                        id=f"{self.bucket_type}:{self.bucket_name}:{key}",
+                        blob=blob,
+                        source=DocumentSource(self.bucket_type.value),
+                        semantic_identifier=semantic_id,
+                        extension=get_file_ext(file_name),
+                        doc_updated_at=last_modified,
+                        size_bytes=size_bytes if size_bytes else 0
+                    )
+                )
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch = []
+
+            except Exception:
+                logging.exception(f"Error decoding object {key}")

         if batch:
             yield batch
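
The filename-counting pass above only switches to path-based identifiers when a bare filename collides; a standalone sketch of that rule (the bucket keys are made up):

import os

prefix = "reports/"
keys = ["reports/2023/summary.pdf", "reports/2024/summary.pdf", "reports/2024/unique.pdf"]

filename_counts: dict[str, int] = {}
for key in keys:
    name = os.path.basename(key)
    filename_counts[name] = filename_counts.get(name, 0) + 1

for key in keys:
    name = os.path.basename(key)
    if filename_counts[name] > 1:
        relative = key[len(prefix):] if key.startswith(prefix) else key
        semantic_id = relative.replace("/", " / ")
    else:
        semantic_id = name
    print(semantic_id)
# -> "2023 / summary.pdf", "2024 / summary.pdf", "unique.pdf"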
|
|||||||
@ -13,6 +13,9 @@ def get_current_tz_offset() -> int:
|
|||||||
return round(time_diff.total_seconds() / 3600)
|
return round(time_diff.total_seconds() / 3600)
|
||||||
|
|
||||||
|
|
||||||
|
# Default request timeout, mostly used by connectors
|
||||||
|
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
|
||||||
|
|
||||||
ONE_MINUTE = 60
|
ONE_MINUTE = 60
|
||||||
ONE_HOUR = 3600
|
ONE_HOUR = 3600
|
||||||
ONE_DAY = ONE_HOUR * 24
|
ONE_DAY = ONE_HOUR * 24
|
||||||
@@ -53,6 +56,14 @@ class DocumentSource(str, Enum):
     S3_COMPATIBLE = "s3_compatible"
     DROPBOX = "dropbox"
     BOX = "box"
+    AIRTABLE = "airtable"
+    ASANA = "asana"
+    GITHUB = "github"
+    GITLAB = "gitlab"
+    IMAP = "imap"
+    BITBUCKET = "bitbucket"
+    ZENDESK = "zendesk"
+
 
 class FileOrigin(str, Enum):
     """File origins"""
@@ -83,6 +94,7 @@ _PAGE_EXPANSION_FIELDS = [
     "space",
     "metadata.labels",
     "history.lastUpdated",
+    "ancestors",
 ]
 
 
@@ -229,6 +241,8 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
 
 BOX_WEB_OAUTH_REDIRECT_URI = os.environ.get("BOX_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/box/oauth/web/callback")
 
+GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None
+
 class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
     # remove links entirely
     STRIP = "strip"
@@ -250,6 +264,22 @@ WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get(
     "WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside"
 ).split(",")
 
+AIRTABLE_CONNECTOR_SIZE_THRESHOLD = int(
+    os.environ.get("AIRTABLE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
+)
+
+ASANA_CONNECTOR_SIZE_THRESHOLD = int(
+    os.environ.get("ASANA_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
+)
+
+IMAP_CONNECTOR_SIZE_THRESHOLD = int(
+    os.environ.get("IMAP_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
+)
+
+ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
+    "ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
+).split(",")
+
 _USER_NOT_FOUND = "Unknown Confluence User"
 
 _COMMENT_EXPANSION_FIELDS = ["body.storage.value"]
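For reference, a sketch (not code from the repo) of how the new settings behave when the environment variables are left unset: the size thresholds default to 10 MiB, and splitting the empty Zendesk default yields a list containing one empty string, so consumers presumably have to ignore empty entries.

import os

AIRTABLE_CONNECTOR_SIZE_THRESHOLD = int(
    os.environ.get("AIRTABLE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
print(AIRTABLE_CONNECTOR_SIZE_THRESHOLD)  # 10485760 bytes, i.e. a 10 MiB default

ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
    "ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
).split(",")
print(ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS)  # [''] when the variable is unset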
@@ -186,7 +186,7 @@ class OnyxConfluence:
         # between the db and redis everywhere the credentials might be updated
         new_credential_str = json.dumps(new_credentials)
         self.redis_client.set(
-            self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
+            self.credential_key, new_credential_str, exp=self.CREDENTIAL_TTL
         )
         self._credentials_provider.set_credentials(new_credentials)
 
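For context on the credential-caching change: the old call used redis-py keyword arguments, where nx=True writes only if the key is absent and ex sets a TTL, while the new call passes exp=, which suggests self.redis_client is a project wrapper rather than a raw redis-py client (an assumption; the wrapper is not shown here). The old semantics, sketched with redis-py directly and hypothetical key names:

import json
import redis

r = redis.Redis()
new_credentials = {"access_token": "example"}
new_credential_str = json.dumps(new_credentials)

# nx=True: only write if the key does not exist yet; ex=: expire after the TTL.
was_set = r.set("confluence:credentials", new_credential_str, nx=True, ex=3600)
print(was_set)  # True on first write, None if another worker already cached credentials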
@@ -1311,6 +1311,9 @@ class ConfluenceConnector(
         self._low_timeout_confluence_client: OnyxConfluence | None = None
         self._fetched_titles: set[str] = set()
         self.allow_images = False
+        # Track document names to detect duplicates
+        self._document_name_counts: dict[str, int] = {}
+        self._document_name_paths: dict[str, list[str]] = {}
 
         # Remove trailing slash from wiki_base if present
         self.wiki_base = wiki_base.rstrip("/")
@@ -1513,6 +1516,40 @@ class ConfluenceConnector(
             self.wiki_base, page["_links"]["webui"], self.is_cloud
         )
 
+        # Build hierarchical path for semantic identifier
+        space_name = page.get("space", {}).get("name", "")
+
+        # Build path from ancestors
+        path_parts = []
+        if space_name:
+            path_parts.append(space_name)
+
+        # Add ancestor pages to path if available
+        if "ancestors" in page and page["ancestors"]:
+            for ancestor in page["ancestors"]:
+                ancestor_title = ancestor.get("title", "")
+                if ancestor_title:
+                    path_parts.append(ancestor_title)
+
+        # Add current page title
+        path_parts.append(page_title)
+
+        # Track page names for duplicate detection
+        full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title
+
+        # Count occurrences of this page title
+        if page_title not in self._document_name_counts:
+            self._document_name_counts[page_title] = 0
+            self._document_name_paths[page_title] = []
+        self._document_name_counts[page_title] += 1
+        self._document_name_paths[page_title].append(full_path)
+
+        # Use simple name if no duplicates, otherwise use full path
+        if self._document_name_counts[page_title] == 1:
+            semantic_identifier = page_title
+        else:
+            semantic_identifier = full_path
+
         # Get the page content
         page_content = extract_text_from_confluence_html(
             self.confluence_client, page, self._fetched_titles
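A condensed sketch of the page-naming rule added above, with the connector's counters passed in explicitly; the function and variable names here are illustrative, not from the repo. Note that, as written, the first page indexed under a given title keeps the short name and only later duplicates switch to the full path.

def resolve_page_identifier(
    page: dict,
    name_counts: dict[str, int],
    name_paths: dict[str, list[str]],
) -> str:
    page_title = page.get("title", "")
    space_name = page.get("space", {}).get("name", "")

    # Space / ancestor pages / current page, in order.
    path_parts = [space_name] if space_name else []
    for ancestor in page.get("ancestors") or []:
        if ancestor.get("title"):
            path_parts.append(ancestor["title"])
    path_parts.append(page_title)
    full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title

    # First occurrence keeps the plain title; later duplicates use the full path.
    name_counts[page_title] = name_counts.get(page_title, 0) + 1
    name_paths.setdefault(page_title, []).append(full_path)
    return page_title if name_counts[page_title] == 1 else full_path

counts: dict[str, int] = {}
paths: dict[str, list[str]] = {}
page_a = {"title": "Overview", "space": {"name": "Docs"}, "ancestors": [{"title": "Guides"}]}
page_b = {"title": "Overview", "space": {"name": "Eng"}, "ancestors": []}
print(resolve_page_identifier(page_a, counts, paths))  # Overview
print(resolve_page_identifier(page_b, counts, paths))  # Eng / Overview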
@@ -1559,11 +1596,11 @@ class ConfluenceConnector(
         return Document(
             id=page_url,
             source=DocumentSource.CONFLUENCE,
-            semantic_identifier=page_title,
+            semantic_identifier=semantic_identifier,
             extension=".html",  # Confluence pages are HTML
             blob=page_content.encode("utf-8"),  # Encode page content as bytes
-            size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
             doc_updated_at=datetime_from_string(page["version"]["when"]),
+            size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
             primary_owners=primary_owners if primary_owners else None,
             metadata=metadata if metadata else None,
         )
@@ -1601,7 +1638,6 @@ class ConfluenceConnector(
             expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
         ):
             media_type: str = attachment.get("metadata", {}).get("mediaType", "")
-
             # TODO(rkuo): this check is partially redundant with validate_attachment_filetype
             # and checks in convert_attachment_to_content/process_attachment
             # but doing the check here avoids an unnecessary download. Due for refactoring.
@@ -1669,6 +1705,34 @@ class ConfluenceConnector(
                 self.wiki_base, attachment["_links"]["webui"], self.is_cloud
             )
 
+            # Build semantic identifier with space and page context
+            attachment_title = attachment.get("title", object_url)
+            space_name = page.get("space", {}).get("name", "")
+            page_title = page.get("title", "")
+
+            # Create hierarchical name: Space / Page / Attachment
+            attachment_path_parts = []
+            if space_name:
+                attachment_path_parts.append(space_name)
+            if page_title:
+                attachment_path_parts.append(page_title)
+            attachment_path_parts.append(attachment_title)
+
+            full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title
+
+            # Track attachment names for duplicate detection
+            if attachment_title not in self._document_name_counts:
+                self._document_name_counts[attachment_title] = 0
+                self._document_name_paths[attachment_title] = []
+            self._document_name_counts[attachment_title] += 1
+            self._document_name_paths[attachment_title].append(full_attachment_path)
+
+            # Use simple name if no duplicates, otherwise use full path
+            if self._document_name_counts[attachment_title] == 1:
+                attachment_semantic_identifier = attachment_title
+            else:
+                attachment_semantic_identifier = full_attachment_path
+
             primary_owners: list[BasicExpertInfo] | None = None
             if "version" in attachment and "by" in attachment["version"]:
                 author = attachment["version"]["by"]
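The attachment branch applies the same idea with a fixed Space / Page / Attachment order, and it shares the _document_name_counts dictionary with pages. A tiny sketch of the path-joining rule, with hypothetical helper and argument names:

def attachment_display_name(space_name: str, page_title: str, attachment_title: str) -> str:
    # Drop empty segments, then join; a lone title stays as-is.
    parts = [p for p in (space_name, page_title, attachment_title) if p]
    return " / ".join(parts) if len(parts) > 1 else attachment_title

print(attachment_display_name("Docs", "Overview", "diagram.png"))  # Docs / Overview / diagram.png
print(attachment_display_name("", "", "diagram.png"))              # diagram.png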
@@ -1680,11 +1744,12 @@ class ConfluenceConnector(
 
             extension = Path(attachment.get("title", "")).suffix or ".unknown"
 
+
             attachment_doc = Document(
                 id=attachment_id,
                 # sections=sections,
                 source=DocumentSource.CONFLUENCE,
-                semantic_identifier=attachment.get("title", object_url),
+                semantic_identifier=attachment_semantic_identifier,
                 extension=extension,
                 blob=file_blob,
                 size_bytes=len(file_blob),
Some files were not shown because too many files have changed in this diff.