mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-22 14:16:42 +08:00
Compare commits
353 Commits
v0.22.0
...
alert-auto
| Author | SHA1 | Date | |
|---|---|---|---|
| 5749abfd89 | |||
| 6c9afd1ffb | |||
| bfef96d56e | |||
| 74adf3d59c | |||
| ba7e087aef | |||
| f911aa2997 | |||
| 42f9ac997f | |||
| 15eccb445d | |||
| c7cf7aad4e | |||
| 2118bc2556 | |||
| b49eb6826b | |||
| 8dd2394e93 | |||
| 5aea82d9c4 | |||
| 47005ebe10 | |||
| 3ee47e4af7 | |||
| 55c0468ac9 | |||
| eeb36a5ce7 | |||
| aceca266ff | |||
| d82e502a71 | |||
| 0494b92371 | |||
| 8683a5b1b7 | |||
| 4cbe470089 | |||
| 6cd1824a77 | |||
| 2844700dc4 | |||
| f8fd1ea7e1 | |||
| 57edc215d7 | |||
| 7a4044b05f | |||
| e84d5412bc | |||
| 151480dc85 | |||
| 2331b3a270 | |||
| 5cd1a678c8 | |||
| cc9546b761 | |||
| a63dcfed6f | |||
| 4dd8cdc38b | |||
| 1a4822d6be | |||
| ce161f09cc | |||
| 672958a192 | |||
| 3820de916c | |||
| ef44979b5c | |||
| d38f8a1562 | |||
| 8e4d011b15 | |||
| 7baa67dfe8 | |||
| e58271ef76 | |||
| 4fd4a41e7c | |||
| 82d4e5fb87 | |||
| d16643a53d | |||
| 93ca1e0b91 | |||
| 4046bffaf1 | |||
| 03f9be7cbb | |||
| 5e05f43c3d | |||
| 205a6483f5 | |||
| 2595644dfd | |||
| 30019dab9f | |||
| 4d46726eb7 | |||
| 0e8b9588ba | |||
| 344a106eba | |||
| bccad7b4a8 | |||
| f7926724aa | |||
| 5bba562048 | |||
| 49c74d08e8 | |||
| 1112b6291b | |||
| ef5d1d4b74 | |||
| a98887d4ca | |||
| 7ca3e11566 | |||
| a2e080c2d3 | |||
| ad6f7fd4b0 | |||
| 2a0f835ffe | |||
| 13d8241eee | |||
| 1ddd11f045 | |||
| 81eb03d230 | |||
| 7d23c3aed0 | |||
| 6be0338aa0 | |||
| 44dec89f1f | |||
| 2b260901df | |||
| 948bc93786 | |||
| 0f0fb53256 | |||
| 0fcb1680fd | |||
| 50715ba332 | |||
| f9510edbbc | |||
| 6560388f2b | |||
| e37aea5f81 | |||
| 7db9045b74 | |||
| a6bd765a02 | |||
| 74afb8d710 | |||
| ea4a5cd665 | |||
| 22a51a3868 | |||
| e9710b7aa9 | |||
| bd0eff2954 | |||
| e3cfe8e848 | |||
| c610bb605a | |||
| a6afb7dfe2 | |||
| 7b96113d4c | |||
| 8370bc61b7 | |||
| 74eb894453 | |||
| 34d29d7e8b | |||
| badf33e3b9 | |||
| 3cb72377d7 | |||
| ab4b62031f | |||
| 80f3ccf1ac | |||
| a1164b9c89 | |||
| fd7e55b23d | |||
| f128a1fa9e | |||
| 65a5a56d95 | |||
| ca2d6f3301 | |||
| a94b3b9df2 | |||
| 30377319d8 | |||
| 07dca37ef0 | |||
| 036b29f084 | |||
| 9863862348 | |||
| bb6022477e | |||
| 28bc87c5e2 | |||
| c51e6b2a58 | |||
| 481192300d | |||
| 1777620ea5 | |||
| f3a03b06b2 | |||
| dd046be976 | |||
| 5c9672a265 | |||
| 09a3854ed8 | |||
| 43f51baa96 | |||
| 5a2011e687 | |||
| 7dd9ce0b5f | |||
| b66881a371 | |||
| 4d7934061e | |||
| 660fa8888b | |||
| 3285f09c92 | |||
| 51ec708c58 | |||
| 9b8971a9de | |||
| 6546f86b4e | |||
| 8de6b97806 | |||
| e4e0a88053 | |||
| 7719fd6350 | |||
| 15ef6dd72f | |||
| 5b5f19cbc1 | |||
| ea38e12d42 | |||
| 885eb2eab9 | |||
| 6587acef88 | |||
| ad03ede7cd | |||
| 468e4042c2 | |||
| af1344033d | |||
| 4012d65b3c | |||
| e2bc1a3478 | |||
| 6c2c447a72 | |||
| e7022db9a4 | |||
| ca4a0ee1b2 | |||
| 27b0550876 | |||
| 797e03f843 | |||
| b4e06237ef | |||
| 751a13fb64 | |||
| fa7b857aa9 | |||
| 257af75ece | |||
| cbdacf21f6 | |||
| b1f3130519 | |||
| 3c224c817b | |||
| a3c9402218 | |||
| a7d40e9132 | |||
| 648342b62f | |||
| 4870d42949 | |||
| caaf7043cc | |||
| 237a66913b | |||
| 3c50c7d3ac | |||
| b44e65a12e | |||
| e3f40db963 | |||
| b5ad7b7062 | |||
| 6fc7def562 | |||
| c8f608b2dd | |||
| 5c81e01de5 | |||
| 83fac6d0a0 | |||
| a6681d6366 | |||
| 1388c4420d | |||
| 962bd5f5df | |||
| 627c11c429 | |||
| 4ba17361e9 | |||
| c946858328 | |||
| ba6e2af5fd | |||
| 2ffe6f7439 | |||
| e3987e21b9 | |||
| a713f54732 | |||
| 519f03097e | |||
| 299c655e39 | |||
| b8c0fb4572 | |||
| d1e172171f | |||
| 81ae6cf78d | |||
| 1120575021 | |||
| 221947acc4 | |||
| 21d8ffca56 | |||
| 41cff3e09e | |||
| b6c4722687 | |||
| 6ea4248bdc | |||
| 88a28212b3 | |||
| 9d0309aedc | |||
| 9a8ce9d3e2 | |||
| 7499608a8b | |||
| 0ebbb60102 | |||
| 80f6d22d2a | |||
| 088b049b4c | |||
| fa9b7b259c | |||
| 14616cf845 | |||
| d2915f6984 | |||
| ccce8beeeb | |||
| 3d2e0f1a1b | |||
| 918d5a9ff8 | |||
| 7d05d4ced7 | |||
| dbdda0fbab | |||
| cf7fdd274b | |||
| 982ed233a2 | |||
| 1f96c95b42 | |||
| 8604c4f57c | |||
| a674338c21 | |||
| 89d82ff031 | |||
| c71d25f744 | |||
| f57f32cf3a | |||
| b6314164c5 | |||
| 856201c0f2 | |||
| 9d8b96c1d0 | |||
| 7c3c185038 | |||
| a9259917c6 | |||
| 8c28587821 | |||
| 12979a3f21 | |||
| 376eb15c63 | |||
| 89ba7abe30 | |||
| 2fd5ac1031 | |||
| 40e84ca41a | |||
| a28c672695 | |||
| 74e0b58d89 | |||
| 7c20c964b4 | |||
| 5d0981d046 | |||
| a793dd2ea8 | |||
| 915e385244 | |||
| 7a344a32f9 | |||
| 8c1ee3845a | |||
| 8c751d5afc | |||
| f5faf0c94f | |||
| af72e8dc33 | |||
| bcd70affb5 | |||
| 6987e9f23b | |||
| 41665b0865 | |||
| d1744aaaf3 | |||
| d5f8548200 | |||
| 4d8698624c | |||
| 1009819801 | |||
| 8fe782f4ea | |||
| 7140950e93 | |||
| 0181747881 | |||
| 3c41159d26 | |||
| e0e1d04da5 | |||
| f0a14f5fce | |||
| 174a2578e8 | |||
| a0959b9d38 | |||
| 13299197b8 | |||
| 249296e417 | |||
| db0f6840d9 | |||
| 1033a3ae26 | |||
| 1845daf41f | |||
| 4c8f9f0d77 | |||
| cc00c3ec93 | |||
| 653b785958 | |||
| 971c1bcba7 | |||
| 065917bf1c | |||
| 820934fc77 | |||
| d3d2ccc76c | |||
| c8ab9079b3 | |||
| 0d5589bfda | |||
| b846a0f547 | |||
| 69578ebfce | |||
| 06cef71ba6 | |||
| d2b1da0e26 | |||
| 7c6d30f4c8 | |||
| ea0352ee4a | |||
| fa5cf10f56 | |||
| 3fe71ab7dd | |||
| 9f715d6bc2 | |||
| 48de3b26ba | |||
| 273c4bc4d3 | |||
| 420c97199a | |||
| ecf0322165 | |||
| 38234aca53 | |||
| 1c06ec39ca | |||
| cfdccebb17 | |||
| 980a883033 | |||
| 02d429f0ca | |||
| 9c24d5d44a | |||
| 0cc5d7a8a6 | |||
| c43bf1dcf5 | |||
| f76b8279dd | |||
| db5ec89dc5 | |||
| 1c201c4d54 | |||
| ba78d0f0c2 | |||
| add8c63458 | |||
| 83661efdaf | |||
| 971197d595 | |||
| 0884e9a4d9 | |||
| 2de42f00b8 | |||
| e8fe580d7a | |||
| 62505164d5 | |||
| d1dcf3b43c | |||
| f84662d2ee | |||
| 1cb6b7f5dd | |||
| 023f509501 | |||
| 50bc53a1f5 | |||
| 8cd4882596 | |||
| 35e5fade93 | |||
| 4942a23290 | |||
| d1716d865a | |||
| c2b7c305fa | |||
| 341e5904c8 | |||
| ded9bf80c5 | |||
| fea157ba08 | |||
| 0db00f70b2 | |||
| 701761d119 | |||
| 2993fc666b | |||
| 8a6d205df0 | |||
| 912b6b023e | |||
| 89e8818dda | |||
| 1dba6b5bf9 | |||
| 3fcf2ee54c | |||
| d8f413a885 | |||
| 7264fb6978 | |||
| bd4bc57009 | |||
| 0569b50fed | |||
| 6b64641042 | |||
| 9cef3a2625 | |||
| e7e89d3ecb | |||
| 13e212c856 | |||
| 61cf430dbb | |||
| e841b09d63 | |||
| b1a1eedf53 | |||
| 68e3b33ae4 | |||
| cd55f6c1b8 | |||
| 996b5fe14e | |||
| db4fd19c82 | |||
| 12db62b9c7 | |||
| b5f2cf16bc | |||
| e27ff8d3d4 | |||
| 5f59418aba | |||
| 87e69868c0 | |||
| 72c20022f6 | |||
| 3f2472f1b9 | |||
| 1d4d67daf8 | |||
| 7538e218a5 | |||
| 6b52f7df5a | |||
| 63131ec9b2 | |||
| e8f1a245a6 | |||
| 908450509f | |||
| 70a0f081f6 | |||
| 93422fa8cc | |||
| bfc84ba95b | |||
| 871055b0fc | |||
| ba71160b14 | |||
| bd5dda6b10 | |||
| 774563970b | |||
| 83d84e90ed | |||
| 8ef2f79d0a | |||
| 296476ab89 |
1
.github/copilot-instructions.md
vendored
Normal file
1
.github/copilot-instructions.md
vendored
Normal file
@ -0,0 +1 @@
|
||||
Refer to [AGENTS.MD](../AGENTS.md) for all repo instructions.
|
||||
21
.github/workflows/release.yml
vendored
21
.github/workflows/release.yml
vendored
@ -3,11 +3,12 @@ name: release
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 13 * * *' # This schedule runs every 13:00:00Z(21:00:00+08:00)
|
||||
# https://github.com/orgs/community/discussions/26286?utm_source=chatgpt.com#discussioncomment-3251208
|
||||
# "The create event does not support branch filter and tag filter."
|
||||
# The "create tags" trigger is specifically focused on the creation of new tags, while the "push tags" trigger is activated when tags are pushed, including both new tag creations and updates to existing tags.
|
||||
create:
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*" # normal release
|
||||
- "nightly" # the only one mutable tag
|
||||
|
||||
# https://docs.github.com/en/actions/using-jobs/using-concurrency
|
||||
concurrency:
|
||||
@ -21,9 +22,9 @@ jobs:
|
||||
- name: Ensure workspace ownership
|
||||
run: echo "chown -R ${USER} ${GITHUB_WORKSPACE}" && sudo chown -R ${USER} ${GITHUB_WORKSPACE}
|
||||
|
||||
# https://github.com/actions/checkout/blob/v3/README.md
|
||||
# https://github.com/actions/checkout/blob/v6/README.md
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }} # Use the secret as an environment variable
|
||||
fetch-depth: 0
|
||||
@ -31,12 +32,12 @@ jobs:
|
||||
|
||||
- name: Prepare release body
|
||||
run: |
|
||||
if [[ ${GITHUB_EVENT_NAME} == "create" ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
RELEASE_TAG=${GITHUB_REF#refs/tags/}
|
||||
if [[ ${RELEASE_TAG} == "nightly" ]]; then
|
||||
PRERELEASE=true
|
||||
else
|
||||
if [[ ${RELEASE_TAG} == v* ]]; then
|
||||
PRERELEASE=false
|
||||
else
|
||||
PRERELEASE=true
|
||||
fi
|
||||
echo "Workflow triggered by create tag: ${RELEASE_TAG}"
|
||||
else
|
||||
@ -55,7 +56,7 @@ jobs:
|
||||
git fetch --tags
|
||||
if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then
|
||||
# Determine if a given tag exists and matches a specific Git commit.
|
||||
# actions/checkout@v4 fetch-tags doesn't work when triggered by schedule
|
||||
# actions/checkout@v6 fetch-tags doesn't work when triggered by schedule
|
||||
if [ "$(git rev-parse -q --verify "refs/tags/${RELEASE_TAG}")" = "${GITHUB_SHA}" ]; then
|
||||
echo "mutable tag ${RELEASE_TAG} exists and matches ${GITHUB_SHA}"
|
||||
else
|
||||
@ -88,7 +89,7 @@ jobs:
|
||||
- name: Build and push image
|
||||
run: |
|
||||
sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
sudo docker build --build-arg NEED_MIRROR=1 -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
|
||||
sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
|
||||
sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
|
||||
sudo docker push infiniflow/ragflow:${RELEASE_TAG}
|
||||
sudo docker push infiniflow/ragflow:latest
|
||||
|
||||
59
.github/workflows/tests.yml
vendored
59
.github/workflows/tests.yml
vendored
@ -1,4 +1,6 @@
|
||||
name: tests
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
push:
|
||||
@ -12,7 +14,7 @@ on:
|
||||
# The only difference between pull_request and pull_request_target is the context in which the workflow runs:
|
||||
# — pull_request_target workflows use the workflow files from the default branch, and secrets are available.
|
||||
# — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable.
|
||||
pull_request_target:
|
||||
pull_request:
|
||||
types: [ synchronize, ready_for_review ]
|
||||
paths-ignore:
|
||||
- 'docs/**'
|
||||
@ -31,12 +33,9 @@ jobs:
|
||||
name: ragflow_tests
|
||||
# https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
|
||||
# https://github.com/orgs/community/discussions/26261
|
||||
if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }}
|
||||
if: ${{ github.event_name != 'pull_request' || (github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'ci')) }}
|
||||
runs-on: [ "self-hosted", "ragflow-test" ]
|
||||
steps:
|
||||
# https://github.com/hmarr/debug-action
|
||||
#- uses: hmarr/debug-action@v2
|
||||
|
||||
- name: Ensure workspace ownership
|
||||
run: |
|
||||
echo "Workflow triggered by ${{ github.event_name }}"
|
||||
@ -44,7 +43,7 @@ jobs:
|
||||
|
||||
# https://github.com/actions/checkout/issues/1781
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ (github.event_name == 'pull_request' || github.event_name == 'pull_request_target') && format('refs/pull/{0}/merge', github.event.pull_request.number) || github.sha }}
|
||||
fetch-depth: 0
|
||||
@ -53,7 +52,7 @@ jobs:
|
||||
- name: Check workflow duplication
|
||||
if: ${{ !cancelled() && !failure() }}
|
||||
run: |
|
||||
if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} != "pull_request" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
HEAD=$(git rev-parse HEAD)
|
||||
# Find a PR that introduced a given commit
|
||||
gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}"
|
||||
@ -78,7 +77,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then
|
||||
elif [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then
|
||||
PR_NUMBER=${{ github.event.pull_request.number }}
|
||||
PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
|
||||
# Calculate the hash of the current workspace content
|
||||
@ -95,13 +94,53 @@ jobs:
|
||||
version: ">=0.11.x"
|
||||
args: "check"
|
||||
|
||||
- name: Check comments of changed Python files
|
||||
if: ${{ false }}
|
||||
run: |
|
||||
if [[ ${{ github.event_name }} == 'pull_request' || ${{ github.event_name }} == 'pull_request_target' ]]; then
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
|
||||
| grep -E '\.(py)$' || true)
|
||||
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
echo "Check comments of changed Python files with check_comment_ascii.py"
|
||||
|
||||
readarray -t files <<< "$CHANGED_FILES"
|
||||
HAS_ERROR=0
|
||||
|
||||
for file in "${files[@]}"; do
|
||||
if [ -f "$file" ]; then
|
||||
if python3 check_comment_ascii.py "$file"; then
|
||||
echo "✅ $file"
|
||||
else
|
||||
echo "❌ $file"
|
||||
HAS_ERROR=1
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $HAS_ERROR -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "No Python files changed"
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Run unit test
|
||||
run: |
|
||||
uv sync --python 3.12 --group test --frozen
|
||||
source .venv/bin/activate
|
||||
which pytest || echo "pytest not in PATH"
|
||||
echo "Start to run unit test"
|
||||
python3 run_tests.py
|
||||
|
||||
- name: Build ragflow:nightly
|
||||
run: |
|
||||
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:${GITHUB_RUN_ID}
|
||||
echo "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> ${GITHUB_ENV}
|
||||
sudo docker pull ubuntu:22.04
|
||||
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 -f Dockerfile -t ${RAGFLOW_IMAGE} .
|
||||
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} .
|
||||
if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then
|
||||
export HTTP_API_TEST_LEVEL=p3
|
||||
else
|
||||
@ -161,7 +200,7 @@ jobs:
|
||||
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
|
||||
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
|
||||
uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python
|
||||
uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
|
||||
|
||||
- name: Run sdk tests against Elasticsearch
|
||||
run: |
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -195,3 +195,6 @@ ragflow_cli.egg-info
|
||||
|
||||
# Default backup dir
|
||||
backup
|
||||
|
||||
|
||||
.hypothesis
|
||||
110
AGENTS.md
Normal file
110
AGENTS.md
Normal file
@ -0,0 +1,110 @@
|
||||
# RAGFlow Project Instructions for GitHub Copilot
|
||||
|
||||
This file provides context, build instructions, and coding standards for the RAGFlow project.
|
||||
It is structured to follow GitHub Copilot's [customization guidelines](https://docs.github.com/en/copilot/concepts/prompting/response-customization).
|
||||
|
||||
## 1. Project Overview
|
||||
RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding. It is a full-stack application with a Python backend and a React/TypeScript frontend.
|
||||
|
||||
- **Backend**: Python 3.10+ (Flask/Quart)
|
||||
- **Frontend**: TypeScript, React, UmiJS
|
||||
- **Architecture**: Microservices based on Docker.
|
||||
- `api/`: Backend API server.
|
||||
- `rag/`: Core RAG logic (indexing, retrieval).
|
||||
- `deepdoc/`: Document parsing and OCR.
|
||||
- `web/`: Frontend application.
|
||||
|
||||
## 2. Directory Structure
|
||||
- `api/`: Backend API server (Flask/Quart).
|
||||
- `apps/`: API Blueprints (Knowledge Base, Chat, etc.).
|
||||
- `db/`: Database models and services.
|
||||
- `rag/`: Core RAG logic.
|
||||
- `llm/`: LLM, Embedding, and Rerank model abstractions.
|
||||
- `deepdoc/`: Document parsing and OCR modules.
|
||||
- `agent/`: Agentic reasoning components.
|
||||
- `web/`: Frontend application (React + UmiJS).
|
||||
- `docker/`: Docker deployment configurations.
|
||||
- `sdk/`: Python SDK.
|
||||
- `test/`: Backend tests.
|
||||
|
||||
## 3. Build Instructions
|
||||
|
||||
### Backend (Python)
|
||||
The project uses **uv** for dependency management.
|
||||
|
||||
1. **Setup Environment**:
|
||||
```bash
|
||||
uv sync --python 3.12 --all-extras
|
||||
uv run download_deps.py
|
||||
```
|
||||
|
||||
2. **Run Server**:
|
||||
- **Pre-requisite**: Start dependent services (MySQL, ES/Infinity, Redis, MinIO).
|
||||
```bash
|
||||
docker compose -f docker/docker-compose-base.yml up -d
|
||||
```
|
||||
- **Launch**:
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
export PYTHONPATH=$(pwd)
|
||||
bash docker/launch_backend_service.sh
|
||||
```
|
||||
|
||||
### Frontend (TypeScript/React)
|
||||
Located in `web/`.
|
||||
|
||||
1. **Install Dependencies**:
|
||||
```bash
|
||||
cd web
|
||||
npm install
|
||||
```
|
||||
|
||||
2. **Run Dev Server**:
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
Runs on port 8000 by default.
|
||||
|
||||
### Docker Deployment
|
||||
To run the full stack using Docker:
|
||||
```bash
|
||||
cd docker
|
||||
docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
## 4. Testing Instructions
|
||||
|
||||
### Backend Tests
|
||||
- **Run All Tests**:
|
||||
```bash
|
||||
uv run pytest
|
||||
```
|
||||
- **Run Specific Test**:
|
||||
```bash
|
||||
uv run pytest test/test_api.py
|
||||
```
|
||||
|
||||
### Frontend Tests
|
||||
- **Run Tests**:
|
||||
```bash
|
||||
cd web
|
||||
npm run test
|
||||
```
|
||||
|
||||
## 5. Coding Standards & Guidelines
|
||||
- **Python Formatting**: Use `ruff` for linting and formatting.
|
||||
```bash
|
||||
ruff check
|
||||
ruff format
|
||||
```
|
||||
- **Frontend Linting**:
|
||||
```bash
|
||||
cd web
|
||||
npm run lint
|
||||
```
|
||||
- **Pre-commit**: Ensure pre-commit hooks are installed.
|
||||
```bash
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
@ -45,7 +45,7 @@ RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on d
|
||||
### Backend Development
|
||||
```bash
|
||||
# Install Python dependencies
|
||||
uv sync --python 3.10 --all-extras
|
||||
uv sync --python 3.12 --all-extras
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
|
||||
|
||||
46
Dockerfile
46
Dockerfile
@ -1,5 +1,5 @@
|
||||
# base stage
|
||||
FROM ubuntu:22.04 AS base
|
||||
FROM ubuntu:24.04 AS base
|
||||
USER root
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
@ -10,11 +10,10 @@ WORKDIR /ragflow
|
||||
# Copy models downloaded via download_deps.py
|
||||
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
|
||||
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
|
||||
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
|
||||
tar --exclude='.*' -cf - \
|
||||
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
|
||||
/huggingface.co/InfiniFlow/deepdoc \
|
||||
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
||||
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
||||
|
||||
# https://github.com/chrismattmann/tika-python
|
||||
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
|
||||
@ -34,34 +33,41 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
# selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85
|
||||
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
|
||||
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt update && \
|
||||
apt --no-install-recommends install -y ca-certificates; \
|
||||
if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
sed -i 's|http://ports.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
|
||||
sed -i 's|http://archive.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
|
||||
sed -i 's|http://archive.ubuntu.com/ubuntu|https://mirrors.tuna.tsinghua.edu.cn/ubuntu|g' /etc/apt/sources.list.d/ubuntu.sources; \
|
||||
sed -i 's|http://security.ubuntu.com/ubuntu|https://mirrors.tuna.tsinghua.edu.cn/ubuntu|g' /etc/apt/sources.list.d/ubuntu.sources; \
|
||||
fi; \
|
||||
rm -f /etc/apt/apt.conf.d/docker-clean && \
|
||||
echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \
|
||||
chmod 1777 /tmp && \
|
||||
apt update && \
|
||||
apt --no-install-recommends install -y ca-certificates && \
|
||||
apt update && \
|
||||
apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \
|
||||
apt install -y pkg-config libicu-dev libgdiplus && \
|
||||
apt install -y default-jdk && \
|
||||
apt install -y libatk-bridge2.0-0 && \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript && \
|
||||
apt install -y pandoc && \
|
||||
apt install -y texlive && \
|
||||
apt install -y fonts-freefont-ttf fonts-noto-cjk
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
pip3 config set global.trusted-host pypi.tuna.tsinghua.edu.cn; \
|
||||
# Install uv
|
||||
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
|
||||
if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
mkdir -p /etc/uv && \
|
||||
echo "[[index]]" > /etc/uv/uv.toml && \
|
||||
echo 'python-install-mirror = "https://registry.npmmirror.com/-/binary/python-build-standalone/"' > /etc/uv/uv.toml && \
|
||||
echo '[[index]]' >> /etc/uv/uv.toml && \
|
||||
echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
|
||||
echo "default = true" >> /etc/uv/uv.toml; \
|
||||
echo 'default = true' >> /etc/uv/uv.toml; \
|
||||
fi; \
|
||||
pipx install uv
|
||||
tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \
|
||||
&& cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \
|
||||
&& rm -rf uv-x86_64-unknown-linux-gnu \
|
||||
&& uv python install 3.11
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
|
||||
ENV PATH=/root/.local/bin:$PATH
|
||||
@ -77,12 +83,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
||||
# A modern version of cargo is needed for the latest version of the Rust compiler.
|
||||
RUN apt update && apt install -y curl build-essential \
|
||||
&& if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
# Use TUNA mirrors for rustup/rust dist files
|
||||
# Use TUNA mirrors for rustup/rust dist files \
|
||||
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
|
||||
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
|
||||
echo "Using TUNA mirrors for Rustup."; \
|
||||
fi; \
|
||||
# Force curl to use HTTP/1.1
|
||||
# Force curl to use HTTP/1.1 \
|
||||
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
|
||||
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc
|
||||
|
||||
@ -99,10 +105,10 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt update && \
|
||||
arch="$(uname -m)"; \
|
||||
if [ "$arch" = "arm64" ] || [ "$arch" = "aarch64" ]; then \
|
||||
# ARM64 (macOS/Apple Silicon or Linux aarch64)
|
||||
# ARM64 (macOS/Apple Silicon or Linux aarch64) \
|
||||
ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql18; \
|
||||
else \
|
||||
# x86_64 or others
|
||||
# x86_64 or others \
|
||||
ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql17; \
|
||||
fi || \
|
||||
{ echo "Failed to install ODBC driver"; exit 1; }
|
||||
@ -146,7 +152,7 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \
|
||||
else \
|
||||
sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
|
||||
fi; \
|
||||
uv sync --python 3.10 --frozen
|
||||
uv sync --python 3.12 --frozen
|
||||
|
||||
COPY web web
|
||||
COPY docs docs
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
FROM scratch
|
||||
|
||||
# Copy resources downloaded via download_deps.py
|
||||
COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb /
|
||||
COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz /
|
||||
|
||||
COPY nltk_data /nltk_data
|
||||
|
||||
|
||||
25
README.md
25
README.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,7 +85,8 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Latest Updates
|
||||
|
||||
- 2025-11-12 Supports data synchronization from Confluence, AWS S3, Discord, Google Drive.
|
||||
- 2025-11-19 Supports Gemini 3 Pro.
|
||||
- 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, Google Drive.
|
||||
- 2025-10-23 Supports MinerU & Docling as document parsing methods.
|
||||
- 2025-10-15 Supports orchestrable ingestion pipeline.
|
||||
- 2025-08-08 Supports OpenAI's latest GPT-5 series models.
|
||||
@ -93,8 +94,6 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
- 2025-05-23 Adds a Python/JavaScript code executor component to Agent.
|
||||
- 2025-05-05 Supports cross-language query.
|
||||
- 2025-03-19 Supports using a multi-modal model to make sense of images within PDF or DOCX files.
|
||||
- 2024-12-18 Upgrades Document Layout Analysis model in DeepDoc.
|
||||
- 2024-08-22 Support text to SQL statements through RAG.
|
||||
|
||||
## 🎉 Stay Tuned
|
||||
|
||||
@ -188,13 +187,15 @@ releases! 🌟
|
||||
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
|
||||
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
|
||||
|
||||
> The command below downloads the `v0.22.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
|
||||
> The command below downloads the `v0.22.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# git checkout v0.22.1
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
|
||||
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
@ -205,10 +206,10 @@ releases! 🌟
|
||||
|
||||
> Note: Prior to `v0.22.0`, we provided both images with embedding models and slim images without embedding models. Details as follows:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|----------------|
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> Starting with `v0.22.0`, we ship only the slim edition and no longer append the **-slim** suffix to the image tag.
|
||||
|
||||
@ -313,7 +314,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
uv sync --python 3.10 # install RAGFlow dependent python modules
|
||||
uv sync --python 3.12 # install RAGFlow dependent python modules
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
21
README_id.md
21
README_id.md
@ -22,7 +22,7 @@
|
||||
<img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
|
||||
@ -85,7 +85,8 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Pembaruan Terbaru
|
||||
|
||||
- 2025-11-12 Mendukung sinkronisasi data dari Confluence, AWS S3, Discord, Google Drive.
|
||||
- 2025-11-19 Mendukung Gemini 3 Pro.
|
||||
- 2025-11-12 Mendukung sinkronisasi data dari Confluence, S3, Notion, Discord, Google Drive.
|
||||
- 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen.
|
||||
- 2025-10-15 Dukungan untuk jalur data yang terorkestrasi.
|
||||
- 2025-08-08 Mendukung model seri GPT-5 terbaru dari OpenAI.
|
||||
@ -186,12 +187,14 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
|
||||
> Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> Perintah di bawah ini mengunduh edisi v0.22.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
|
||||
> Perintah di bawah ini mengunduh edisi v0.22.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
@ -203,10 +206,10 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
> Catatan: Sebelum `v0.22.0`, kami menyediakan image dengan model embedding dan image slim tanpa model embedding. Detailnya sebagai berikut:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|----------------|
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> Mulai dari `v0.22.0`, kami hanya menyediakan edisi slim dan tidak lagi menambahkan akhiran **-slim** pada tag image.
|
||||
|
||||
@ -285,7 +288,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
uv sync --python 3.10 # install RAGFlow dependent python modules
|
||||
uv sync --python 3.12 # install RAGFlow dependent python modules
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
21
README_ja.md
21
README_ja.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -66,7 +66,8 @@
|
||||
|
||||
## 🔥 最新情報
|
||||
|
||||
- 2025-11-12 Confluence、AWS S3、Discord、Google Drive からのデータ同期をサポートします。
|
||||
- 2025-11-19 Gemini 3 Proをサポートしています
|
||||
- 2025-11-12 Confluence、S3、Notion、Discord、Google Drive からのデータ同期をサポートします。
|
||||
- 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。
|
||||
- 2025-10-15 オーケストレーションされたデータパイプラインのサポート。
|
||||
- 2025-08-08 OpenAI の最新 GPT-5 シリーズモデルをサポートします。
|
||||
@ -166,12 +167,14 @@
|
||||
> 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
|
||||
> ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
|
||||
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases)
|
||||
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
@ -183,10 +186,10 @@
|
||||
|
||||
> 注意:`v0.22.0` より前のバージョンでは、embedding モデルを含むイメージと、embedding モデルを含まない slim イメージの両方を提供していました。詳細は以下の通りです:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|----------------|
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> `v0.22.0` 以降、当プロジェクトでは slim エディションのみを提供し、イメージタグに **-slim** サフィックスを付けなくなりました。
|
||||
|
||||
@ -285,7 +288,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
uv sync --python 3.10 # install RAGFlow dependent python modules
|
||||
uv sync --python 3.12 # install RAGFlow dependent python modules
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
21
README_ko.md
21
README_ko.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -67,7 +67,8 @@
|
||||
|
||||
## 🔥 업데이트
|
||||
|
||||
- 2025-11-12 Confluence, AWS S3, Discord, Google Drive에서 데이터 동기화를 지원합니다.
|
||||
- 2025-11-19 Gemini 3 Pro를 지원합니다.
|
||||
- 2025-11-12 Confluence, S3, Notion, Discord, Google Drive에서 데이터 동기화를 지원합니다.
|
||||
- 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다.
|
||||
- 2025-10-15 조정된 데이터 파이프라인 지원.
|
||||
- 2025-08-08 OpenAI의 최신 GPT-5 시리즈 모델을 지원합니다.
|
||||
@ -168,12 +169,14 @@
|
||||
> 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
|
||||
> ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
@ -185,10 +188,10 @@
|
||||
|
||||
> 참고: `v0.22.0` 이전 버전에서는 embedding 모델이 포함된 이미지와 embedding 모델이 포함되지 않은 slim 이미지를 모두 제공했습니다. 자세한 내용은 다음과 같습니다:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|----------------|
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> `v0.22.0`부터는 slim 에디션만 배포하며 이미지 태그에 **-slim** 접미사를 더 이상 붙이지 않습니다.
|
||||
|
||||
@ -280,7 +283,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
uv sync --python 3.10 # install RAGFlow dependent python modules
|
||||
uv sync --python 3.12 # install RAGFlow dependent python modules
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
|
||||
@ -86,7 +86,8 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Últimas Atualizações
|
||||
|
||||
- 12-11-2025 Suporta a sincronização de dados do Confluence, AWS S3, Discord e Google Drive.
|
||||
- 19-11-2025 Suporta Gemini 3 Pro.
|
||||
- 12-11-2025 Suporta a sincronização de dados do Confluence, S3, Notion, Discord e Google Drive.
|
||||
- 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos.
|
||||
- 15-10-2025 Suporte para pipelines de dados orquestrados.
|
||||
- 08-08-2025 Suporta a mais recente série GPT-5 da OpenAI.
|
||||
@ -186,12 +187,14 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
|
||||
> Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
|
||||
|
||||
> O comando abaixo baixa a edição`v0.22.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
|
||||
> O comando abaixo baixa a edição`v0.22.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
@ -203,10 +206,10 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
> Nota: Antes da `v0.22.0`, fornecíamos imagens com modelos de embedding e imagens slim sem modelos de embedding. Detalhes a seguir:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|----------------|
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> A partir da `v0.22.0`, distribuímos apenas a edição slim e não adicionamos mais o sufixo **-slim** às tags das imagens.
|
||||
|
||||
@ -302,7 +305,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
uv sync --python 3.10 # instala os módulos Python dependentes do RAGFlow
|
||||
uv sync --python 3.12 # instala os módulos Python dependentes do RAGFlow
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,7 +85,8 @@
|
||||
|
||||
## 🔥 近期更新
|
||||
|
||||
- 2025-11-12 支援從 Confluence、AWS S3、Discord、Google Drive 進行資料同步。
|
||||
- 2025-11-19 支援 Gemini 3 Pro.
|
||||
- 2025-11-12 支援從 Confluence、S3、Notion、Discord、Google Drive 進行資料同步。
|
||||
- 2025-10-23 支援 MinerU 和 Docling 作為文件解析方法。
|
||||
- 2025-10-15 支援可編排的資料管道。
|
||||
- 2025-08-08 支援 OpenAI 最新的 GPT-5 系列模型。
|
||||
@ -185,12 +186,14 @@
|
||||
> 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。
|
||||
> 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。
|
||||
|
||||
> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.0` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
|
||||
> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases),例:git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases)
|
||||
# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
@ -202,10 +205,10 @@
|
||||
|
||||
> 注意:在 `v0.22.0` 之前的版本,我們會同時提供包含 embedding 模型的映像和不含 embedding 模型的 slim 映像。具體如下:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|----------------|
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> 從 `v0.22.0` 開始,我們只發佈 slim 版本,並且不再在映像標籤後附加 **-slim** 後綴。
|
||||
|
||||
@ -312,7 +315,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
uv sync --python 3.10 # install RAGFlow dependent python modules
|
||||
uv sync --python 3.12 # install RAGFlow dependent python modules
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
21
README_zh.md
21
README_zh.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,7 +85,8 @@
|
||||
|
||||
## 🔥 近期更新
|
||||
|
||||
- 2025-11-12 支持从 Confluence、AWS S3、Discord、Google Drive 进行数据同步。
|
||||
- 2025-11-19 支持 Gemini 3 Pro.
|
||||
- 2025-11-12 支持从 Confluence、S3、Notion、Discord、Google Drive 进行数据同步。
|
||||
- 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。
|
||||
- 2025-10-15 支持可编排的数据管道。
|
||||
- 2025-08-08 支持 OpenAI 最新的 GPT-5 系列模型。
|
||||
@ -186,12 +187,14 @@
|
||||
> 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。
|
||||
> 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。
|
||||
|
||||
> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.0` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
|
||||
> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases)
|
||||
# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
@ -203,10 +206,10 @@
|
||||
|
||||
> 注意:在 `v0.22.0` 之前的版本,我们会同时提供包含 embedding 模型的镜像和不含 embedding 模型的 slim 镜像。具体如下:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
|-------------------|-----------------|-----------------------|----------------|
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> 从 `v0.22.0` 开始,我们只发布 slim 版本,并且不再在镜像标签后附加 **-slim** 后缀。
|
||||
|
||||
@ -312,7 +315,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
uv sync --python 3.10 # install RAGFlow dependent python modules
|
||||
uv sync --python 3.12 # install RAGFlow dependent python modules
|
||||
uv run download_deps.py
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
@ -6,8 +6,8 @@ Use this section to tell people about which versions of your project are
|
||||
currently being supported with security updates.
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| <=0.7.0 | :white_check_mark: |
|
||||
|---------|--------------------|
|
||||
| <=0.7.0 | :white_check_mark: |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
|
||||
|
||||
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
|
||||
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
|
||||
|
||||
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
|
||||
|
||||
@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
|
||||
1. Ensure the Admin Service is running.
|
||||
2. Install ragflow-cli.
|
||||
```bash
|
||||
pip install ragflow-cli==0.22.0
|
||||
pip install ragflow-cli==0.22.1
|
||||
```
|
||||
3. Launch the CLI client:
|
||||
```bash
|
||||
|
||||
@ -351,7 +351,7 @@ class AdminCLI(Cmd):
|
||||
def verify_admin(self, arguments: dict, single_command: bool):
|
||||
self.host = arguments['host']
|
||||
self.port = arguments['port']
|
||||
print(f"Attempt to access ip: {self.host}, port: {self.port}")
|
||||
print("Attempt to access server for admin login")
|
||||
url = f"http://{self.host}:{self.port}/api/v1/admin/login"
|
||||
|
||||
attempt_count = 3
|
||||
@ -378,7 +378,7 @@ class AdminCLI(Cmd):
|
||||
self.session.headers.update({
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': response.headers['Authorization'],
|
||||
'User-Agent': 'RAGFlow-CLI/0.22.0'
|
||||
'User-Agent': 'RAGFlow-CLI/0.22.1'
|
||||
})
|
||||
print("Authentication successful.")
|
||||
return True
|
||||
@ -390,10 +390,12 @@ class AdminCLI(Cmd):
|
||||
print(f"Bad response,status: {response.status_code}, password is wrong")
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print(f"Can't access {self.host}, port: {self.port}")
|
||||
print("Can't access server for admin login (connection failed)")
|
||||
|
||||
def _format_service_detail_table(self, data):
|
||||
if not any([isinstance(v, list) for v in data.values()]):
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
if not all([isinstance(v, list) for v in data.values()]):
|
||||
# normal table
|
||||
return data
|
||||
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
|
||||
@ -404,7 +406,7 @@ class AdminCLI(Cmd):
|
||||
task_executor_list.append({
|
||||
"task_executor_name": k,
|
||||
**heartbeats[0],
|
||||
})
|
||||
} if heartbeats else {"task_executor_name": k})
|
||||
return task_executor_list
|
||||
|
||||
def _print_table_simple(self, data):
|
||||
@ -415,7 +417,8 @@ class AdminCLI(Cmd):
|
||||
# handle single row data
|
||||
data = [data]
|
||||
|
||||
columns = list(data[0].keys())
|
||||
columns = list(set().union(*(d.keys() for d in data)))
|
||||
columns.sort()
|
||||
col_widths = {}
|
||||
|
||||
def get_string_width(text):
|
||||
@ -671,7 +674,7 @@ class AdminCLI(Cmd):
|
||||
user_name: str = user_name_tree.children[0].strip("'\"")
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
print(f"Alter user: {user_name}, password: {password}")
|
||||
print(f"Alter user: {user_name}, password: ******")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password'
|
||||
response = self.session.put(url, json={'new_password': encrypt(password)})
|
||||
res_json = response.json()
|
||||
@ -686,7 +689,7 @@ class AdminCLI(Cmd):
|
||||
password_tree: Tree = command['password']
|
||||
password: str = password_tree.children[0].strip("'\"")
|
||||
role: str = command['role']
|
||||
print(f"Create user: {user_name}, password: {password}, role: {role}")
|
||||
print(f"Create user: {user_name}, password: ******, role: {role}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users'
|
||||
response = self.session.post(
|
||||
url,
|
||||
@ -948,7 +951,7 @@ def main():
|
||||
|
||||
args = cli.parse_connection_args(sys.argv)
|
||||
if 'error' in args:
|
||||
print(f"Error: {args['error']}")
|
||||
print("Error: Invalid connection arguments")
|
||||
return
|
||||
|
||||
if 'command' in args:
|
||||
@ -957,7 +960,7 @@ def main():
|
||||
return
|
||||
if cli.verify_admin(args, single_command=True):
|
||||
command: str = args['command']
|
||||
print(f"Run single command: {command}")
|
||||
# print(f"Run single command: {command}")
|
||||
cli.run_single_command(command)
|
||||
else:
|
||||
if cli.verify_admin(args, single_command=False):
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
[project]
|
||||
name = "ragflow-cli"
|
||||
version = "0.22.0"
|
||||
version = "0.22.1"
|
||||
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
|
||||
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
|
||||
license = { text = "Apache License, Version 2.0" }
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.13"
|
||||
requires-python = ">=3.12,<3.15"
|
||||
dependencies = [
|
||||
"requests>=2.30.0,<3.0.0",
|
||||
"beartype>=0.18.5,<0.19.0",
|
||||
"beartype>=0.20.0,<1.0.0",
|
||||
"pycryptodomex>=3.10.0",
|
||||
"lark>=1.1.0",
|
||||
]
|
||||
|
||||
298
admin/client/uv.lock
generated
Normal file
298
admin/client/uv.lock
generated
Normal file
@ -0,0 +1,298 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
requires-python = ">=3.10, <3.13"
|
||||
|
||||
[[package]]
|
||||
name = "beartype"
|
||||
version = "0.22.6"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/88/e2/105ceb1704cb80fe4ab3872529ab7b6f365cf7c74f725e6132d0efcf1560/beartype-0.22.6.tar.gz", hash = "sha256:97fbda69c20b48c5780ac2ca60ce3c1bb9af29b3a1a0216898ffabdd523e48f4", size = 1588975, upload-time = "2025-11-20T04:47:14.736Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/98/c9/ceecc71fe2c9495a1d8e08d44f5f31f5bca1350d5b2e27a4b6265424f59e/beartype-0.22.6-py3-none-any.whl", hash = "sha256:0584bc46a2ea2a871509679278cda992eadde676c01356ab0ac77421f3c9a093", size = 1324807, upload-time = "2025-11-20T04:47:11.837Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2025.11.12"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/8c/58f469717fa48465e4a50c014a0400602d3c437d7c0c468e17ada824da3a/certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316", size = 160538, upload-time = "2025-11-12T02:54:51.517Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b", size = 159438, upload-time = "2025-11-12T02:54:49.735Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.4"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/b8/6d51fc1d52cbd52cd4ccedd5b5b2f0f6a11bbf6765c782298b0f3e808541/charset_normalizer-3.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d", size = 209709, upload-time = "2025-10-14T04:40:11.385Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/5c/af/1f9d7f7faafe2ddfb6f72a2e07a548a629c61ad510fe60f9630309908fef/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8", size = 148814, upload-time = "2025-10-14T04:40:13.135Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/79/3d/f2e3ac2bbc056ca0c204298ea4e3d9db9b4afe437812638759db2c976b5f/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad", size = 144467, upload-time = "2025-10-14T04:40:14.728Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ec/85/1bf997003815e60d57de7bd972c57dc6950446a3e4ccac43bc3070721856/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8", size = 162280, upload-time = "2025-10-14T04:40:16.14Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3e/8e/6aa1952f56b192f54921c436b87f2aaf7c7a7c3d0d1a765547d64fd83c13/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d", size = 159454, upload-time = "2025-10-14T04:40:17.567Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/36/3b/60cbd1f8e93aa25d1c669c649b7a655b0b5fb4c571858910ea9332678558/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313", size = 153609, upload-time = "2025-10-14T04:40:19.08Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/64/91/6a13396948b8fd3c4b4fd5bc74d045f5637d78c9675585e8e9fbe5636554/charset_normalizer-3.4.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e", size = 151849, upload-time = "2025-10-14T04:40:20.607Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/7a/59482e28b9981d105691e968c544cc0df3b7d6133152fb3dcdc8f135da7a/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93", size = 151586, upload-time = "2025-10-14T04:40:21.719Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/59/f64ef6a1c4bdd2baf892b04cd78792ed8684fbc48d4c2afe467d96b4df57/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0", size = 145290, upload-time = "2025-10-14T04:40:23.069Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6b/63/3bf9f279ddfa641ffa1962b0db6a57a9c294361cc2f5fcac997049a00e9c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84", size = 163663, upload-time = "2025-10-14T04:40:24.17Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/09/c9e38fc8fa9e0849b172b581fd9803bdf6e694041127933934184e19f8c3/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e", size = 151964, upload-time = "2025-10-14T04:40:25.368Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d2/d1/d28b747e512d0da79d8b6a1ac18b7ab2ecfd81b2944c4c710e166d8dd09c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db", size = 161064, upload-time = "2025-10-14T04:40:26.806Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/bb/9a/31d62b611d901c3b9e5500c36aab0ff5eb442043fb3a1c254200d3d397d9/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6", size = 155015, upload-time = "2025-10-14T04:40:28.284Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/f3/107e008fa2bff0c8b9319584174418e5e5285fef32f79d8ee6a430d0039c/charset_normalizer-3.4.4-cp310-cp310-win32.whl", hash = "sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f", size = 99792, upload-time = "2025-10-14T04:40:29.613Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/eb/66/e396e8a408843337d7315bab30dbf106c38966f1819f123257f5520f8a96/charset_normalizer-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d", size = 107198, upload-time = "2025-10-14T04:40:30.644Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b5/58/01b4f815bf0312704c267f2ccb6e5d42bcc7752340cd487bc9f8c3710597/charset_normalizer-3.4.4-cp310-cp310-win_arm64.whl", hash = "sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69", size = 100262, upload-time = "2025-10-14T04:40:32.108Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988, upload-time = "2025-10-14T04:40:33.79Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324, upload-time = "2025-10-14T04:40:34.961Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742, upload-time = "2025-10-14T04:40:36.105Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863, upload-time = "2025-10-14T04:40:37.188Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837, upload-time = "2025-10-14T04:40:38.435Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550, upload-time = "2025-10-14T04:40:40.053Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162, upload-time = "2025-10-14T04:40:41.163Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019, upload-time = "2025-10-14T04:40:42.276Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310, upload-time = "2025-10-14T04:40:43.439Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022, upload-time = "2025-10-14T04:40:44.547Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383, upload-time = "2025-10-14T04:40:46.018Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098, upload-time = "2025-10-14T04:40:47.081Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991, upload-time = "2025-10-14T04:40:48.246Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456, upload-time = "2025-10-14T04:40:49.376Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 106978, upload-time = "2025-10-14T04:40:50.844Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969, upload-time = "2025-10-14T04:40:52.272Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.11"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lark"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/da/34/28fff3ab31ccff1fd4f6c7c7b0ceb2b6968d8ea4950663eadcb5720591a0/lark-1.3.1.tar.gz", hash = "sha256:b426a7a6d6d53189d318f2b6236ab5d6429eaf09259f1ca33eb716eed10d2905", size = 382732, upload-time = "2025-10-27T18:25:56.653Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/82/3d/14ce75ef66813643812f3093ab17e46d3a206942ce7376d31ec2d36229e7/lark-1.3.1-py3-none-any.whl", hash = "sha256:c629b661023a014c37da873b4ff58a817398d12635d3bbb2c5a03be7fe5d1e12", size = 113151, upload-time = "2025-10-27T18:25:54.882Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "25.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycryptodomex"
|
||||
version = "3.23.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/85/e24bf90972a30b0fcd16c73009add1d7d7cd9140c2498a68252028899e41/pycryptodomex-3.23.0.tar.gz", hash = "sha256:71909758f010c82bc99b0abf4ea12012c98962fbf0583c2164f8b84533c2e4da", size = 4922157, upload-time = "2025-05-17T17:23:41.434Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/dd/9c/1a8f35daa39784ed8adf93a694e7e5dc15c23c741bbda06e1d45f8979e9e/pycryptodomex-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:06698f957fe1ab229a99ba2defeeae1c09af185baa909a31a5d1f9d42b1aaed6", size = 2499240, upload-time = "2025-05-17T17:22:46.953Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/62/f5221a191a97157d240cf6643747558759126c76ee92f29a3f4aee3197a5/pycryptodomex-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b2c2537863eccef2d41061e82a881dcabb04944c5c06c5aa7110b577cc487545", size = 1644042, upload-time = "2025-05-17T17:22:49.098Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/8c/fd/5a054543c8988d4ed7b612721d7e78a4b9bf36bc3c5ad45ef45c22d0060e/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43c446e2ba8df8889e0e16f02211c25b4934898384c1ec1ec04d7889c0333587", size = 2186227, upload-time = "2025-05-17T17:22:51.139Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c8/a9/8862616a85cf450d2822dbd4fff1fcaba90877907a6ff5bc2672cafe42f8/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f489c4765093fb60e2edafdf223397bc716491b2b69fe74367b70d6999257a5c", size = 2272578, upload-time = "2025-05-17T17:22:53.676Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/46/9f/bda9c49a7c1842820de674ab36c79f4fbeeee03f8ff0e4f3546c3889076b/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bdc69d0d3d989a1029df0eed67cc5e8e5d968f3724f4519bd03e0ec68df7543c", size = 2312166, upload-time = "2025-05-17T17:22:56.585Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/cc/870b9bf8ca92866ca0186534801cf8d20554ad2a76ca959538041b7a7cf4/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:6bbcb1dd0f646484939e142462d9e532482bc74475cecf9c4903d4e1cd21f003", size = 2185467, upload-time = "2025-05-17T17:22:59.237Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/96/e3/ce9348236d8e669fea5dd82a90e86be48b9c341210f44e25443162aba187/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:8a4fcd42ccb04c31268d1efeecfccfd1249612b4de6374205376b8f280321744", size = 2346104, upload-time = "2025-05-17T17:23:02.112Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a5/e9/e869bcee87beb89040263c416a8a50204f7f7a83ac11897646c9e71e0daf/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:55ccbe27f049743a4caf4f4221b166560d3438d0b1e5ab929e07ae1702a4d6fd", size = 2271038, upload-time = "2025-05-17T17:23:04.872Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/8d/67/09ee8500dd22614af5fbaa51a4aee6e342b5fa8aecf0a6cb9cbf52fa6d45/pycryptodomex-3.23.0-cp37-abi3-win32.whl", hash = "sha256:189afbc87f0b9f158386bf051f720e20fa6145975f1e76369303d0f31d1a8d7c", size = 1771969, upload-time = "2025-05-17T17:23:07.115Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/69/96/11f36f71a865dd6df03716d33bd07a67e9d20f6b8d39820470b766af323c/pycryptodomex-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:52e5ca58c3a0b0bd5e100a9fbc8015059b05cffc6c66ce9d98b4b45e023443b9", size = 1803124, upload-time = "2025-05-17T17:23:09.267Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/93/45c1cdcbeb182ccd2e144c693eaa097763b08b38cded279f0053ed53c553/pycryptodomex-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:02d87b80778c171445d67e23d1caef279bf4b25c3597050ccd2e13970b57fd51", size = 1707161, upload-time = "2025-05-17T17:23:11.414Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/b8/3e76d948c3c4ac71335bbe75dac53e154b40b0f8f1f022dfa295257a0c96/pycryptodomex-3.23.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ebfff755c360d674306e5891c564a274a47953562b42fb74a5c25b8fc1fb1cb5", size = 1627695, upload-time = "2025-05-17T17:23:17.38Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6a/cf/80f4297a4820dfdfd1c88cf6c4666a200f204b3488103d027b5edd9176ec/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eca54f4bb349d45afc17e3011ed4264ef1cc9e266699874cdd1349c504e64798", size = 1675772, upload-time = "2025-05-17T17:23:19.202Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/42/1e969ee0ad19fe3134b0e1b856c39bd0b70d47a4d0e81c2a8b05727394c9/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2596e643d4365e14d0879dc5aafe6355616c61c2176009270f3048f6d9a61f", size = 1668083, upload-time = "2025-05-17T17:23:21.867Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/c3/1de4f7631fea8a992a44ba632aa40e0008764c0fb9bf2854b0acf78c2cf2/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdfac7cda115bca3a5abb2f9e43bc2fb66c2b65ab074913643803ca7083a79ea", size = 1706056, upload-time = "2025-05-17T17:23:24.031Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f2/5f/af7da8e6f1e42b52f44a24d08b8e4c726207434e2593732d39e7af5e7256/pycryptodomex-3.23.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:14c37aaece158d0ace436f76a7bb19093db3b4deade9797abfc39ec6cd6cc2fe", size = 1806478, upload-time = "2025-05-17T17:23:26.066Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "9.0.1"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/56/f013048ac4bc4c1d9be45afd4ab209ea62822fb1598f40687e6bf45dcea4/pytest-9.0.1.tar.gz", hash = "sha256:3e9c069ea73583e255c3b21cf46b8d3c56f6e3a1a8f6da94ccb0fcf57b9d73c8", size = 1564125, upload-time = "2025-11-12T13:05:09.333Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0b/8b/6300fb80f858cda1c51ffa17075df5d846757081d11ab4aa35cef9e6258b/pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad", size = 373668, upload-time = "2025-11-12T13:05:07.379Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ragflow-cli"
|
||||
version = "0.22.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "beartype" },
|
||||
{ name = "lark" },
|
||||
{ name = "pycryptodomex" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "requests" },
|
||||
{ name = "requests-toolbelt" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "beartype", specifier = ">=0.20.0,<1.0.0" },
|
||||
{ name = "lark", specifier = ">=1.1.0" },
|
||||
{ name = "pycryptodomex", specifier = ">=3.10.0" },
|
||||
{ name = "requests", specifier = ">=2.30.0,<3.0.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
test = [
|
||||
{ name = "pytest", specifier = ">=8.3.5" },
|
||||
{ name = "requests", specifier = ">=2.32.3" },
|
||||
{ name = "requests-toolbelt", specifier = ">=1.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.5"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "charset-normalizer" },
|
||||
{ name = "idna" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests-toolbelt"
|
||||
version = "1.0.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/61/d7545dafb7ac2230c70d38d31cbfe4cc64f7144dc41f6e4e4b78ecd9f5bb/requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6", size = 206888, upload-time = "2023-05-01T04:11:33.229Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.15.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.5.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
|
||||
]
|
||||
@ -20,8 +20,11 @@ import logging
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
from werkzeug.serving import run_simple
|
||||
import faulthandler
|
||||
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager
|
||||
from werkzeug.serving import run_simple
|
||||
from routes import admin_bp
|
||||
from common.log_utils import init_root_logger
|
||||
from common.constants import SERVICE_CONF
|
||||
@ -30,12 +33,12 @@ from common import settings
|
||||
from config import load_configurations, SERVICE_CONFIGS
|
||||
from auth import init_default_admin, setup_auth
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from common.versions import get_ragflow_version
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
if __name__ == '__main__':
|
||||
faulthandler.enable()
|
||||
init_root_logger("admin_service")
|
||||
logging.info(r"""
|
||||
____ ___ ______________ ___ __ _
|
||||
|
||||
@ -19,7 +19,8 @@ import logging
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
from flask import request, jsonify
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_login import current_user, login_user
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
@ -30,7 +31,7 @@ from common.constants import ActiveEnum, StatusEnum
|
||||
from api.utils.crypt import decrypt
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, datetime_format, get_format_time
|
||||
from common.connection_utils import construct_response
|
||||
from common.connection_utils import sync_construct_response
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -129,7 +130,7 @@ def login_admin(email: str, password: str):
|
||||
user.last_login_time = get_format_time()
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=resp, auth=user.get_id(), message=msg)
|
||||
return sync_construct_response(data=resp, auth=user.get_id(), message=msg)
|
||||
|
||||
|
||||
def check_admin(username: str, password: str):
|
||||
@ -169,17 +170,17 @@ def login_verify(f):
|
||||
username = auth.parameters['username']
|
||||
password = auth.parameters['password']
|
||||
try:
|
||||
if check_admin(username, password) is False:
|
||||
if not check_admin(username, password):
|
||||
return jsonify({
|
||||
"code": 500,
|
||||
"message": "Access denied",
|
||||
"data": None
|
||||
}), 200
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
except Exception:
|
||||
logging.exception("An error occurred during admin login verification.")
|
||||
return jsonify({
|
||||
"code": 500,
|
||||
"message": error_msg
|
||||
"message": "An internal server error occurred."
|
||||
}), 200
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -25,8 +25,21 @@ from common.config_utils import read_config
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
service_type: str
|
||||
detail_func_name: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
|
||||
'service_type': self.service_type}
|
||||
|
||||
|
||||
class ServiceConfigs:
|
||||
configs = dict
|
||||
configs = list[BaseConfig]
|
||||
|
||||
def __init__(self):
|
||||
self.configs = []
|
||||
@ -45,19 +58,6 @@ class ServiceType(Enum):
|
||||
FILE_STORE = "file_store"
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
service_type: str
|
||||
detail_func_name: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
|
||||
'service_type': self.service_type}
|
||||
|
||||
|
||||
class MetaConfig(BaseConfig):
|
||||
meta_type: str
|
||||
|
||||
@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
|
||||
ragflow_count = 0
|
||||
id_count = 0
|
||||
for k, v in raw_configs.items():
|
||||
match (k):
|
||||
match k:
|
||||
case "ragflow":
|
||||
name: str = f'ragflow_{ragflow_count}'
|
||||
host: str = v['host']
|
||||
|
||||
@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import secrets
|
||||
|
||||
from flask import Blueprint, request
|
||||
from flask_login import current_user, logout_user, login_required
|
||||
from flask_login import current_user, login_required, logout_user
|
||||
|
||||
from auth import login_verify, login_admin, check_admin_auth
|
||||
from responses import success_response, error_response
|
||||
|
||||
@ -13,8 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from werkzeug.security import check_password_hash
|
||||
from common.constants import ActiveEnum
|
||||
@ -190,7 +189,8 @@ class ServiceMgr:
|
||||
config_dict['status'] = service_detail['status']
|
||||
else:
|
||||
config_dict['status'] = 'timeout'
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.warning(f"Can't get service details, error: {e}")
|
||||
config_dict['status'] = 'timeout'
|
||||
if not config_dict['host']:
|
||||
config_dict['host'] = '-'
|
||||
@ -205,17 +205,13 @@ class ServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_service_details(service_id: int):
|
||||
service_id = int(service_id)
|
||||
service_idx = int(service_id)
|
||||
configs = SERVICE_CONFIGS.configs
|
||||
service_config_mapping = {
|
||||
c.id: {
|
||||
'name': c.name,
|
||||
'detail_func_name': c.detail_func_name
|
||||
} for c in configs
|
||||
}
|
||||
service_info = service_config_mapping.get(service_id, {})
|
||||
if not service_info:
|
||||
raise AdminException(f"invalid service_id: {service_id}")
|
||||
if service_idx < 0 or service_idx >= len(configs):
|
||||
raise AdminException(f"invalid service_index: {service_idx}")
|
||||
|
||||
service_config = configs[service_idx]
|
||||
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
|
||||
|
||||
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
|
||||
res = detail_func()
|
||||
|
||||
@ -13,6 +13,3 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
|
||||
293
agent/canvas.py
293
agent/canvas.py
@ -13,7 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import binascii
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -26,7 +29,9 @@ from typing import Any, Union, Tuple
|
||||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.constants import LLMType
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
from rag.prompts.generator import chunks_format
|
||||
@ -80,14 +85,12 @@ class Graph:
|
||||
self.dsl = json.loads(dsl)
|
||||
self._tenant_id = tenant_id
|
||||
self.task_id = task_id if task_id else get_uuid()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=5)
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
self.components = self.dsl["components"]
|
||||
cpn_nms = set([])
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||
@ -157,7 +160,7 @@ class Graph:
|
||||
return self._tenant_id
|
||||
|
||||
def get_value_with_variable(self,value: str) -> Any:
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
@ -207,17 +210,60 @@ class Graph:
|
||||
for key in path.split('.'):
|
||||
if cur is None:
|
||||
return None
|
||||
|
||||
if isinstance(cur, str):
|
||||
try:
|
||||
cur = json.loads(cur)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(key)
|
||||
else:
|
||||
cur = getattr(cur, key, None)
|
||||
continue
|
||||
|
||||
if isinstance(cur, (list, tuple)):
|
||||
try:
|
||||
idx = int(key)
|
||||
cur = cur[idx]
|
||||
except Exception:
|
||||
return None
|
||||
continue
|
||||
|
||||
cur = getattr(cur, key, None)
|
||||
return cur
|
||||
|
||||
def set_variable_value(self, exp: str,value):
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
self.globals[exp] = value
|
||||
return
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
parts = var_nm.split(".", 1)
|
||||
root_key = parts[0]
|
||||
rest = parts[1] if len(parts) > 1 else ""
|
||||
if not rest:
|
||||
cpn["obj"].set_output(root_key, value)
|
||||
return
|
||||
root_val = cpn["obj"].output(root_key)
|
||||
if not root_val:
|
||||
root_val = {}
|
||||
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value))
|
||||
|
||||
def set_variable_param_value(self, obj: Any, path: str, value) -> Any:
|
||||
cur = obj
|
||||
keys = path.split('.')
|
||||
if not path:
|
||||
return value
|
||||
for key in keys:
|
||||
if key not in cur or not isinstance(cur[key], dict):
|
||||
cur[key] = {}
|
||||
cur = cur[key]
|
||||
cur[keys[-1]] = value
|
||||
return obj
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
return has_canceled(self.task_id)
|
||||
|
||||
@ -239,6 +285,7 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
self.variables = {}
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
|
||||
def load(self):
|
||||
@ -253,6 +300,10 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
if "variables" in self.dsl:
|
||||
self.variables = self.dsl["variables"]
|
||||
else:
|
||||
self.variables = {}
|
||||
|
||||
self.retrieval = self.dsl["retrieval"]
|
||||
self.memory = self.dsl.get("memory", [])
|
||||
@ -269,6 +320,7 @@ class Canvas(Graph):
|
||||
self.history = []
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
print(self.variables)
|
||||
for k in self.globals.keys():
|
||||
if k.startswith("sys."):
|
||||
if isinstance(self.globals[k], str):
|
||||
@ -283,9 +335,31 @@ class Canvas(Graph):
|
||||
self.globals[k] = {}
|
||||
else:
|
||||
self.globals[k] = None
|
||||
if k.startswith("env."):
|
||||
key = k[4:]
|
||||
if key in self.variables:
|
||||
variable = self.variables[key]
|
||||
if variable["value"]:
|
||||
self.globals[k] = variable["value"]
|
||||
else:
|
||||
if variable["type"] == "string":
|
||||
self.globals[k] = ""
|
||||
elif variable["type"] == "number":
|
||||
self.globals[k] = 0
|
||||
elif variable["type"] == "boolean":
|
||||
self.globals[k] = False
|
||||
elif variable["type"] == "object":
|
||||
self.globals[k] = {}
|
||||
elif variable["type"].startswith("array"):
|
||||
self.globals[k] = []
|
||||
else:
|
||||
self.globals[k] = ""
|
||||
else:
|
||||
self.globals[k] = ""
|
||||
|
||||
def run(self, **kwargs):
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self.message_id = get_uuid()
|
||||
created_at = int(time.time())
|
||||
self.add_user_input(kwargs.get("query"))
|
||||
@ -294,16 +368,19 @@ class Canvas(Graph):
|
||||
|
||||
if kwargs.get("webhook_payload"):
|
||||
for k, cpn in self.components.items():
|
||||
if self.components[k]["obj"].component_name.lower() == "webhook":
|
||||
for kk, vv in kwargs["webhook_payload"].items():
|
||||
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
|
||||
payload = kwargs.get("webhook_payload", {})
|
||||
if "input" in payload:
|
||||
self.components[k]["obj"].set_input_value("request", payload["input"])
|
||||
for kk, vv in payload.items():
|
||||
if kk == "input":
|
||||
continue
|
||||
self.components[k]["obj"].set_output(kk, vv)
|
||||
|
||||
self.components[k]["obj"].reset(True)
|
||||
|
||||
for k in kwargs.keys():
|
||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||
if k == "files":
|
||||
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
|
||||
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
|
||||
else:
|
||||
self.globals[f"sys.{k}"] = kwargs[k]
|
||||
if not self.globals["sys.conversation_turns"] :
|
||||
@ -333,31 +410,50 @@ class Canvas(Graph):
|
||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||
|
||||
def _run_batch(f, t):
|
||||
async def _run_batch(f, t):
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
||||
i += 1
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
|
||||
def _run_async_in_thread(coro_func, **call_kwargs):
|
||||
return asyncio.run(coro_func(**call_kwargs))
|
||||
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
task_fn = None
|
||||
call_kwargs = None
|
||||
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
call_kwargs = {"inputs": kwargs.get("inputs", {})}
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||
i += 1
|
||||
for t in thr:
|
||||
t.result()
|
||||
call_kwargs = cpn.get_input()
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
|
||||
if task_fn is None:
|
||||
continue
|
||||
|
||||
invoke_async = getattr(cpn, "invoke_async", None)
|
||||
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
||||
else:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _node_finished(cpn_obj):
|
||||
return decorate("node_finished",{
|
||||
@ -374,6 +470,7 @@ class Canvas(Graph):
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
partials = []
|
||||
tts_mdl = None
|
||||
while idx < len(self.path):
|
||||
to = len(self.path)
|
||||
for i in range(idx, to):
|
||||
@ -384,31 +481,70 @@ class Canvas(Graph):
|
||||
"component_type": self.get_component_type(self.path[i]),
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
_run_batch(idx, to)
|
||||
await _run_batch(idx, to)
|
||||
to = len(self.path)
|
||||
# post processing of components invocation
|
||||
# post-processing of components invocation
|
||||
for i in range(idx, to):
|
||||
cpn = self.get_component(self.path[i])
|
||||
cpn_obj = self.get_component_obj(self.path[i])
|
||||
if cpn_obj.component_name.lower() == "message":
|
||||
if cpn_obj.get_param("auto_play"):
|
||||
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
_m = ""
|
||||
for m in cpn_obj.output("content")():
|
||||
buff_m = ""
|
||||
stream = cpn_obj.output("content")()
|
||||
async def _process_stream(m):
|
||||
nonlocal buff_m, _m, tts_mdl
|
||||
if not m:
|
||||
continue
|
||||
return
|
||||
if m == "<think>":
|
||||
yield decorate("message", {"content": "", "start_to_think": True})
|
||||
return decorate("message", {"content": "", "start_to_think": True})
|
||||
|
||||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
return decorate("message", {"content": "", "end_to_think": True})
|
||||
|
||||
buff_m += m
|
||||
_m += m
|
||||
|
||||
if len(buff_m) > 16:
|
||||
ev = decorate(
|
||||
"message",
|
||||
{
|
||||
"content": m,
|
||||
"audio_binary": self.tts(tts_mdl, buff_m)
|
||||
}
|
||||
)
|
||||
buff_m = ""
|
||||
return ev
|
||||
|
||||
return decorate("message", {"content": m})
|
||||
|
||||
if inspect.isasyncgen(stream):
|
||||
async for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
else:
|
||||
for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
if buff_m:
|
||||
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||
buff_m = ""
|
||||
cpn_obj.set_output("content", _m)
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||
else:
|
||||
yield decorate("message", {"content": cpn_obj.output("content")})
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
||||
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
|
||||
|
||||
message_end = {}
|
||||
if isinstance(cpn_obj.output("attachment"), dict):
|
||||
message_end["attachment"] = cpn_obj.output("attachment")
|
||||
if cite:
|
||||
message_end["reference"] = self.get_reference()
|
||||
yield decorate("message_end", message_end)
|
||||
|
||||
while partials:
|
||||
_cpn_obj = self.get_component_obj(partials[0])
|
||||
@ -429,7 +565,7 @@ class Canvas(Graph):
|
||||
else:
|
||||
self.error = cpn_obj.error()
|
||||
|
||||
if cpn_obj.component_name.lower() != "iteration":
|
||||
if cpn_obj.component_name.lower() not in ("iteration","loop"):
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
if self.error:
|
||||
cpn_obj.set_output("content", None)
|
||||
@ -454,14 +590,16 @@ class Canvas(Graph):
|
||||
for cpn_id in cpn_ids:
|
||||
_append_path(cpn_id)
|
||||
|
||||
if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
|
||||
if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end():
|
||||
iter = cpn_obj.get_parent()
|
||||
yield _node_finished(iter)
|
||||
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
|
||||
elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
|
||||
_extend_path(cpn_obj.output("_next"))
|
||||
elif cpn_obj.component_name.lower() == "iteration":
|
||||
elif cpn_obj.component_name.lower() in ("iteration", "loop"):
|
||||
_append_path(cpn_obj.get_start())
|
||||
elif cpn_obj.component_name.lower() == "exitloop" and cpn_obj.get_parent().component_name.lower() == "loop":
|
||||
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
|
||||
elif not cpn["downstream"] and cpn_obj.get_parent():
|
||||
_append_path(cpn_obj.get_parent().get_start())
|
||||
else:
|
||||
@ -517,6 +655,50 @@ class Canvas(Graph):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def tts(self,tts_mdl, text):
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||
|
||||
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||
|
||||
emoji_pattern = re.compile(
|
||||
"[\U0001F600-\U0001F64F"
|
||||
"\U0001F300-\U0001F5FF"
|
||||
"\U0001F680-\U0001F6FF"
|
||||
"\U0001F1E0-\U0001F1FF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U0001F900-\U0001F9FF"
|
||||
"\U0001FA70-\U0001FAFF"
|
||||
"\U0001FAD0-\U0001FAFF]+",
|
||||
flags=re.UNICODE
|
||||
)
|
||||
text = emoji_pattern.sub("", text)
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
MAX_LEN = 500
|
||||
if len(text) > MAX_LEN:
|
||||
text = text[:MAX_LEN]
|
||||
|
||||
return text
|
||||
if not tts_mdl or not text:
|
||||
return None
|
||||
text = clean_tts_text(text)
|
||||
if not text:
|
||||
return None
|
||||
bin = b""
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
except Exception as e:
|
||||
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
@ -546,20 +728,30 @@ class Canvas(Graph):
|
||||
def get_component_input_elements(self, cpnnm):
|
||||
return self.components[cpnnm]["obj"].get_input_elements()
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return [th.result() for th in threads]
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
"""
|
||||
Synchronous wrapper for get_files_async, used by sync component invoke paths.
|
||||
"""
|
||||
loop = getattr(self, "_loop", None)
|
||||
if loop and loop.is_running():
|
||||
return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
|
||||
|
||||
return asyncio.run(self.get_files_async(files))
|
||||
|
||||
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
|
||||
agent_ids = agent_id.split("-->")
|
||||
@ -613,4 +805,3 @@ class Canvas(Graph):
|
||||
|
||||
def get_component_thoughts(self, cpn_id) -> str:
|
||||
return self.components.get(cpn_id)["obj"].thoughts()
|
||||
|
||||
|
||||
@ -13,10 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
@ -28,9 +29,9 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from common.connection_utils import timeout
|
||||
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
|
||||
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
|
||||
@ -137,8 +138,34 @@ class Agent(LLM, ToolBase):
|
||||
res.update(cpn.get_input_form())
|
||||
return res
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
def _get_output_schema(self):
|
||||
try:
|
||||
cand = self._param.outputs.get("structured")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if isinstance(cand, dict):
|
||||
if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
|
||||
return cand
|
||||
for k in ("schema", "structured"):
|
||||
if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
|
||||
return cand[k]
|
||||
|
||||
return None
|
||||
|
||||
async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
|
||||
fmt_msgs = [
|
||||
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
|
||||
{"role": "user", "content": text},
|
||||
]
|
||||
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
|
||||
return await self._generate_async(fmt_msgs)
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
|
||||
@ -157,25 +184,25 @@ class Agent(LLM, ToolBase):
|
||||
if not self.tools:
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
return LLM._invoke(self, **kwargs)
|
||||
return await LLM._invoke_async(self, **kwargs)
|
||||
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
output_schema = self._get_output_schema()
|
||||
schema_prompt = ""
|
||||
if output_schema:
|
||||
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
|
||||
schema_prompt = structured_output_prompt(schema)
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
output_structure=None
|
||||
try:
|
||||
output_structure=self._param.outputs['structured']
|
||||
except Exception:
|
||||
pass
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
||||
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
use_tools = []
|
||||
ans = ""
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
ans += delta_ans
|
||||
@ -188,16 +215,38 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
if output_schema:
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
try:
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
obj = json_repair.loads(clean_formated_answer(ans))
|
||||
self.set_output("structured", obj)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return obj
|
||||
except Exception:
|
||||
error = "The answer cannot be parsed as JSON"
|
||||
ans = await self._force_format_to_schema_async(ans, schema_prompt)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
continue
|
||||
|
||||
self.set_output("_ERROR", error)
|
||||
return
|
||||
|
||||
self.set_output("content", ans)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
||||
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
|
||||
@ -215,39 +264,23 @@ class Agent(LLM, ToolBase):
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
def _gen_citations(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||
{"role": "user", "content": text}
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
|
||||
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||
token_count = 0
|
||||
tool_metas = self.tool_meta
|
||||
hist = deepcopy(history)
|
||||
last_calling = ""
|
||||
if len(hist) > 3:
|
||||
st = timer()
|
||||
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
|
||||
user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
|
||||
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
|
||||
else:
|
||||
user_request = history[-1]["content"]
|
||||
|
||||
def use_tool(name, args):
|
||||
nonlocal hist, use_tools, token_count,last_calling,user_request
|
||||
async def use_tool_async(name, args):
|
||||
nonlocal hist, use_tools, last_calling
|
||||
logging.info(f"{last_calling=} == {name=}")
|
||||
# Summarize of function calling
|
||||
#if all([
|
||||
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
|
||||
# last_calling,
|
||||
# last_calling != name
|
||||
#]):
|
||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
|
||||
last_calling = name
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
tool_response = await self.toolcall_session.tool_call_async(name, args)
|
||||
use_tools.append({
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
@ -258,12 +291,16 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
return name, tool_response
|
||||
|
||||
def complete():
|
||||
async def complete():
|
||||
nonlocal hist
|
||||
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||
if schema_prompt:
|
||||
need2cite = False
|
||||
cited = False
|
||||
if hist[0]["role"] == "system" and need2cite:
|
||||
if len(hist) < 7:
|
||||
if hist and hist[0]["role"] == "system":
|
||||
if schema_prompt:
|
||||
hist[0]["content"] += "\n" + schema_prompt
|
||||
if need2cite and len(hist) < 7:
|
||||
hist[0]["content"] += citation_prompt()
|
||||
cited = True
|
||||
yield "", token_count
|
||||
@ -272,7 +309,7 @@ class Agent(LLM, ToolBase):
|
||||
if len(hist) > 12:
|
||||
_hist = [hist[0], hist[1], *hist[-10:]]
|
||||
entire_txt = ""
|
||||
for delta_ans in self._generate_streamly(_hist):
|
||||
async for delta_ans in self._generate_streamly(_hist):
|
||||
if not need2cite or cited:
|
||||
yield delta_ans, 0
|
||||
entire_txt += delta_ans
|
||||
@ -281,7 +318,7 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
st = timer()
|
||||
txt = ""
|
||||
for delta_ans in self._gen_citations(entire_txt):
|
||||
async for delta_ans in self._gen_citations_async(entire_txt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
yield delta_ans, 0
|
||||
@ -296,14 +333,14 @@ class Agent(LLM, ToolBase):
|
||||
hist.append({"role": "user", "content": content})
|
||||
|
||||
st = timer()
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
token_count += tk
|
||||
token_count += tk or 0
|
||||
hist.append({"role": "assistant", "content": response})
|
||||
try:
|
||||
functions = json_repair.loads(re.sub(r"```.*", "", response))
|
||||
@ -312,23 +349,24 @@ class Agent(LLM, ToolBase):
|
||||
for f in functions:
|
||||
if not isinstance(f, dict):
|
||||
raise TypeError(f"An object type should be returned, but `{f}`")
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
for func in functions:
|
||||
name = func["name"]
|
||||
args = func["arguments"]
|
||||
if name == COMPLETE_TASK:
|
||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
return
|
||||
|
||||
thr.append(executor.submit(use_tool, name, args))
|
||||
tool_tasks = []
|
||||
for func in functions:
|
||||
name = func["name"]
|
||||
args = func["arguments"]
|
||||
if name == COMPLETE_TASK:
|
||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
return
|
||||
|
||||
st = timer()
|
||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
|
||||
|
||||
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
|
||||
st = timer()
|
||||
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
||||
@ -352,27 +390,30 @@ Respond immediately with your final comprehensive answer.
|
||||
return
|
||||
append_user_content(hist, final_instruction)
|
||||
|
||||
for txt, tkcnt in complete():
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
|
||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
|
||||
# self.callback("get_useful_memory", {"topn": 3}, "...")
|
||||
mems = self._canvas.get_memory()
|
||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
|
||||
try:
|
||||
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
|
||||
mems = [mems[r] for r in rank]
|
||||
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
async def _gen_citations_async(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||
async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||
{"role": "user", "content": text}
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
return "Error occurred."
|
||||
|
||||
def reset(self, temp=False):
|
||||
def reset(self, only_output=False):
|
||||
"""
|
||||
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
|
||||
"""
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
|
||||
for k, cpn in self.tools.items():
|
||||
if hasattr(cpn, "reset") and callable(cpn.reset):
|
||||
cpn.reset()
|
||||
|
||||
if only_output:
|
||||
return
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
@ -23,11 +24,9 @@ import os
|
||||
import logging
|
||||
from typing import Any, List, Union
|
||||
import pandas as pd
|
||||
import trio
|
||||
from agent import settings
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
||||
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
|
||||
_DEPRECATED_PARAMS = "_deprecated_params"
|
||||
_USER_FEEDED_PARAMS = "_user_feeded_params"
|
||||
@ -97,7 +96,7 @@ class ComponentParamBase(ABC):
|
||||
def _recursive_convert_obj_to_dict(obj):
|
||||
ret_dict = {}
|
||||
if isinstance(obj, dict):
|
||||
for k,v in obj.items():
|
||||
for k, v in obj.items():
|
||||
if isinstance(v, dict) or (v and type(v).__name__ not in dir(builtins)):
|
||||
ret_dict[k] = _recursive_convert_obj_to_dict(v)
|
||||
else:
|
||||
@ -253,96 +252,65 @@ class ComponentParamBase(ABC):
|
||||
self._validate_param(attr, validation_json)
|
||||
|
||||
@staticmethod
|
||||
def check_string(param, descr):
|
||||
def check_string(param, description):
|
||||
if type(param).__name__ not in ["str"]:
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be string type".format(param)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be string type".format(param))
|
||||
|
||||
@staticmethod
|
||||
def check_empty(param, descr):
|
||||
def check_empty(param, description):
|
||||
if not param:
|
||||
raise ValueError(
|
||||
descr + " does not support empty value."
|
||||
)
|
||||
raise ValueError(description + " does not support empty value.")
|
||||
|
||||
@staticmethod
|
||||
def check_positive_integer(param, descr):
|
||||
def check_positive_integer(param, description):
|
||||
if type(param).__name__ not in ["int", "long"] or param <= 0:
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be positive integer".format(param)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be positive integer".format(param))
|
||||
|
||||
@staticmethod
|
||||
def check_positive_number(param, descr):
|
||||
def check_positive_number(param, description):
|
||||
if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be positive numeric".format(param)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be positive numeric".format(param))
|
||||
|
||||
@staticmethod
|
||||
def check_nonnegative_number(param, descr):
|
||||
def check_nonnegative_number(param, description):
|
||||
if type(param).__name__ not in ["float", "int", "long"] or param < 0:
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be non-negative numeric".format(param)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be non-negative numeric".format(param))
|
||||
|
||||
@staticmethod
|
||||
def check_decimal_float(param, descr):
|
||||
def check_decimal_float(param, description):
|
||||
if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be a float number in range [0, 1]".format(
|
||||
param
|
||||
)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be a float number in range [0, 1]".format(param))
|
||||
|
||||
@staticmethod
|
||||
def check_boolean(param, descr):
|
||||
def check_boolean(param, description):
|
||||
if type(param).__name__ != "bool":
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be bool type".format(param)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be bool type".format(param))
|
||||
|
||||
@staticmethod
|
||||
def check_open_unit_interval(param, descr):
|
||||
def check_open_unit_interval(param, description):
|
||||
if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
|
||||
raise ValueError(
|
||||
descr + " should be a numeric number between 0 and 1 exclusively"
|
||||
)
|
||||
raise ValueError(description + " should be a numeric number between 0 and 1 exclusively")
|
||||
|
||||
@staticmethod
|
||||
def check_valid_value(param, descr, valid_values):
|
||||
def check_valid_value(param, description, valid_values):
|
||||
if param not in valid_values:
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} is not supported, it should be in {}".format(param, valid_values)
|
||||
)
|
||||
raise ValueError(description + " {} is not supported, it should be in {}".format(param, valid_values))
|
||||
|
||||
@staticmethod
|
||||
def check_defined_type(param, descr, types):
|
||||
def check_defined_type(param, description, types):
|
||||
if type(param).__name__ not in types:
|
||||
raise ValueError(
|
||||
descr + " {} not supported, should be one of {}".format(param, types)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be one of {}".format(param, types))
|
||||
|
||||
@staticmethod
|
||||
def check_and_change_lower(param, valid_list, descr=""):
|
||||
def check_and_change_lower(param, valid_list, description=""):
|
||||
if type(param).__name__ != "str":
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be one of {}".format(param, valid_list)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list))
|
||||
|
||||
lower_param = param.lower()
|
||||
if lower_param in valid_list:
|
||||
return lower_param
|
||||
else:
|
||||
raise ValueError(
|
||||
descr
|
||||
+ " {} not supported, should be one of {}".format(param, valid_list)
|
||||
)
|
||||
raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list))
|
||||
|
||||
@staticmethod
|
||||
def _greater_equal_than(value, limit):
|
||||
@ -374,16 +342,16 @@ class ComponentParamBase(ABC):
|
||||
def _not_in(value, wrong_value_list):
|
||||
return value not in wrong_value_list
|
||||
|
||||
def _warn_deprecated_param(self, param_name, descr):
|
||||
def _warn_deprecated_param(self, param_name, description):
|
||||
if self._deprecated_params_set.get(param_name):
|
||||
logging.warning(
|
||||
f"{descr} {param_name} is deprecated and ignored in this version."
|
||||
f"{description} {param_name} is deprecated and ignored in this version."
|
||||
)
|
||||
|
||||
def _warn_to_deprecate_param(self, param_name, descr, new_param):
|
||||
def _warn_to_deprecate_param(self, param_name, description, new_param):
|
||||
if self._deprecated_params_set.get(param_name):
|
||||
logging.warning(
|
||||
f"{descr} {param_name} will be deprecated in future release; "
|
||||
f"{description} {param_name} will be deprecated in future release; "
|
||||
f"please use {new_param} instead."
|
||||
)
|
||||
return True
|
||||
@ -392,8 +360,8 @@ class ComponentParamBase(ABC):
|
||||
|
||||
class ComponentBase(ABC):
|
||||
component_name: str
|
||||
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
@ -407,7 +375,7 @@ class ComponentBase(ABC):
|
||||
"params": {}
|
||||
}}""".format(self.component_name,
|
||||
self._param
|
||||
)
|
||||
)
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
from agent.canvas import Graph # Local import to avoid cyclic dependency
|
||||
@ -445,14 +413,42 @@ class ComponentBase(ABC):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
async def invoke_async(self, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Async wrapper for component invocation.
|
||||
Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
|
||||
Handles timing and error recording consistently with `invoke`.
|
||||
"""
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
if self.check_if_canceled("Component processing"):
|
||||
return
|
||||
|
||||
fn_async = getattr(self, "_invoke_async", None)
|
||||
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||
await fn_async(**kwargs)
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
await self._invoke(**kwargs)
|
||||
else:
|
||||
await asyncio.to_thread(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self._param.debug_inputs = {}
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
|
||||
def output(self, var_nm: str = None) -> Union[dict[str, Any], Any]:
|
||||
if var_nm:
|
||||
return self._param.outputs.get(var_nm, {}).get("value", "")
|
||||
return {k: o.get("value") for k,o in self._param.outputs.items()}
|
||||
return {k: o.get("value") for k, o in self._param.outputs.items()}
|
||||
|
||||
def set_output(self, key: str, value: Any):
|
||||
if key not in self._param.outputs:
|
||||
@ -463,15 +459,18 @@ class ComponentBase(ABC):
|
||||
return self._param.outputs.get("_ERROR", {}).get("value")
|
||||
|
||||
def reset(self, only_output=False):
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
outputs: dict = self._param.outputs # for better performance
|
||||
for k in outputs.keys():
|
||||
outputs[k]["value"] = None
|
||||
if only_output:
|
||||
return
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
|
||||
inputs: dict = self._param.inputs # for better performance
|
||||
for k in inputs.keys():
|
||||
inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
|
||||
def get_input(self, key: str = None) -> Union[Any, dict[str, Any]]:
|
||||
if key:
|
||||
return self._param.inputs.get(key, {}).get("value")
|
||||
|
||||
@ -495,13 +494,13 @@ class ComponentBase(ABC):
|
||||
|
||||
def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]:
|
||||
res = {}
|
||||
for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE|re.DOTALL):
|
||||
for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE | re.DOTALL):
|
||||
exp = r.group(1)
|
||||
cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp)
|
||||
cpn_id, var_nm = exp.split("@") if exp.find("@") > 0 else ("", exp)
|
||||
res[exp] = {
|
||||
"name": (self._canvas.get_component_name(cpn_id) +f"@{var_nm}") if cpn_id else exp,
|
||||
"name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}") if cpn_id else exp,
|
||||
"value": self._canvas.get_variable_value(exp),
|
||||
"_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
|
||||
"_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
|
||||
"_cpn_id": cpn_id
|
||||
}
|
||||
return res
|
||||
@ -552,6 +551,7 @@ class ComponentBase(ABC):
|
||||
for n, v in kv.items():
|
||||
def repl(_match, val=v):
|
||||
return str(val) if val is not None else ""
|
||||
|
||||
content = re.sub(
|
||||
r"\{%s\}" % re.escape(n),
|
||||
repl,
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from agent.component.fillup import UserFillUpParam, UserFillUp
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
class BeginParam(UserFillUpParam):
|
||||
@ -27,7 +28,7 @@ class BeginParam(UserFillUpParam):
|
||||
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
|
||||
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return getattr(self, "inputs")
|
||||
@ -48,7 +49,7 @@ class Begin(UserFillUp):
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
v = self._canvas.get_files([v["value"]])
|
||||
v = FileService.get_files([v["value"]])
|
||||
else:
|
||||
v = v.get("value")
|
||||
self.set_output(k, v)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@ -97,7 +98,7 @@ class Categorize(LLM, ABC):
|
||||
component_name = "Categorize"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Categorize processing"):
|
||||
return
|
||||
|
||||
@ -121,7 +122,7 @@ class Categorize(LLM, ABC):
|
||||
if self.check_if_canceled("Categorize processing"):
|
||||
return
|
||||
|
||||
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||
ans = await chat_mdl.async_chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
||||
if ERROR_PREFIX in ans:
|
||||
raise Exception(ans)
|
||||
@ -144,5 +145,9 @@ class Categorize(LLM, ABC):
|
||||
self.set_output("category_name", max_category)
|
||||
self.set_output("_next", cpn_ids)
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
|
||||
|
||||
@ -1,3 +1,18 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import ast
|
||||
import os
|
||||
|
||||
1570
agent/component/docs_generator.py
Normal file
1570
agent/component/docs_generator.py
Normal file
File diff suppressed because it is too large
Load Diff
401
agent/component/excel_processor.py
Normal file
401
agent/component/excel_processor.py
Normal file
@ -0,0 +1,401 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
ExcelProcessor Component
|
||||
|
||||
A component for reading, processing, and generating Excel files in RAGFlow agents.
|
||||
Supports multiple Excel file inputs, data transformation, and Excel output generation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import timeout
|
||||
from common import settings
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
|
||||
class ExcelProcessorParam(ComponentParamBase):
|
||||
"""
|
||||
Define the ExcelProcessor component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Input configuration
|
||||
self.input_files = [] # Variable references to uploaded files
|
||||
self.operation = "read" # read, merge, transform, output
|
||||
|
||||
# Processing options
|
||||
self.sheet_selection = "all" # all, first, or comma-separated sheet names
|
||||
self.merge_strategy = "concat" # concat, join
|
||||
self.join_on = "" # Column name for join operations
|
||||
|
||||
# Transform options (for LLM-guided transformations)
|
||||
self.transform_instructions = ""
|
||||
self.transform_data = "" # Variable reference to transformation data
|
||||
|
||||
# Output options
|
||||
self.output_format = "xlsx" # xlsx, csv
|
||||
self.output_filename = "output"
|
||||
|
||||
# Component outputs
|
||||
self.outputs = {
|
||||
"data": {
|
||||
"type": "object",
|
||||
"value": {}
|
||||
},
|
||||
"summary": {
|
||||
"type": "str",
|
||||
"value": ""
|
||||
},
|
||||
"markdown": {
|
||||
"type": "str",
|
||||
"value": ""
|
||||
}
|
||||
}
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(
|
||||
self.operation,
|
||||
"[ExcelProcessor] Operation",
|
||||
["read", "merge", "transform", "output"]
|
||||
)
|
||||
self.check_valid_value(
|
||||
self.output_format,
|
||||
"[ExcelProcessor] Output format",
|
||||
["xlsx", "csv"]
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class ExcelProcessor(ComponentBase, ABC):
|
||||
"""
|
||||
Excel processing component for RAGFlow agents.
|
||||
|
||||
Operations:
|
||||
- read: Parse Excel files into structured data
|
||||
- merge: Combine multiple Excel files
|
||||
- transform: Apply data transformations based on instructions
|
||||
- output: Generate Excel file output
|
||||
"""
|
||||
component_name = "ExcelProcessor"
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
"""Define input form for the component."""
|
||||
res = {}
|
||||
for ref in (self._param.input_files or []):
|
||||
for k, o in self.get_input_elements_from_text(ref).items():
|
||||
res[k] = {"name": o.get("name", ""), "type": "file"}
|
||||
if self._param.transform_data:
|
||||
for k, o in self.get_input_elements_from_text(self._param.transform_data).items():
|
||||
res[k] = {"name": o.get("name", ""), "type": "object"}
|
||||
return res
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("ExcelProcessor processing"):
|
||||
return
|
||||
|
||||
operation = self._param.operation.lower()
|
||||
|
||||
if operation == "read":
|
||||
self._read_excels()
|
||||
elif operation == "merge":
|
||||
self._merge_excels()
|
||||
elif operation == "transform":
|
||||
self._transform_data()
|
||||
elif operation == "output":
|
||||
self._output_excel()
|
||||
else:
|
||||
self.set_output("summary", f"Unknown operation: {operation}")
|
||||
|
||||
def _get_file_content(self, file_ref: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
Get file content from a variable reference.
|
||||
Returns (content_bytes, filename).
|
||||
"""
|
||||
value = self._canvas.get_variable_value(file_ref)
|
||||
if value is None:
|
||||
return None, None
|
||||
|
||||
# Handle different value formats
|
||||
if isinstance(value, dict):
|
||||
# File reference from Begin/UserFillUp component
|
||||
file_id = value.get("id") or value.get("file_id")
|
||||
created_by = value.get("created_by") or self._canvas.get_tenant_id()
|
||||
filename = value.get("name") or value.get("filename", "unknown.xlsx")
|
||||
if file_id:
|
||||
content = FileService.get_blob(created_by, file_id)
|
||||
return content, filename
|
||||
elif isinstance(value, list) and len(value) > 0:
|
||||
# List of file references - return first
|
||||
return self._get_file_content_from_list(value[0])
|
||||
elif isinstance(value, str):
|
||||
# Could be base64 encoded or a path
|
||||
if value.startswith("data:"):
|
||||
import base64
|
||||
# Extract base64 content
|
||||
_, encoded = value.split(",", 1)
|
||||
return base64.b64decode(encoded), "uploaded.xlsx"
|
||||
|
||||
return None, None
|
||||
|
||||
def _get_file_content_from_list(self, item) -> tuple[bytes, str]:
|
||||
"""Extract file content from a list item."""
|
||||
if isinstance(item, dict):
|
||||
return self._get_file_content(item)
|
||||
return None, None
|
||||
|
||||
def _parse_excel_to_dataframes(self, content: bytes, filename: str) -> dict[str, pd.DataFrame]:
|
||||
"""Parse Excel content into a dictionary of DataFrames (one per sheet)."""
|
||||
try:
|
||||
excel_file = BytesIO(content)
|
||||
|
||||
if filename.lower().endswith(".csv"):
|
||||
df = pd.read_csv(excel_file)
|
||||
return {"Sheet1": df}
|
||||
else:
|
||||
# Read all sheets
|
||||
xlsx = pd.ExcelFile(excel_file, engine='openpyxl')
|
||||
sheet_selection = self._param.sheet_selection
|
||||
|
||||
if sheet_selection == "all":
|
||||
sheets_to_read = xlsx.sheet_names
|
||||
elif sheet_selection == "first":
|
||||
sheets_to_read = [xlsx.sheet_names[0]] if xlsx.sheet_names else []
|
||||
else:
|
||||
# Comma-separated sheet names
|
||||
requested = [s.strip() for s in sheet_selection.split(",")]
|
||||
sheets_to_read = [s for s in requested if s in xlsx.sheet_names]
|
||||
|
||||
dfs = {}
|
||||
for sheet in sheets_to_read:
|
||||
dfs[sheet] = pd.read_excel(xlsx, sheet_name=sheet)
|
||||
return dfs
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error parsing Excel file {filename}: {e}")
|
||||
return {}
|
||||
|
||||
def _read_excels(self):
|
||||
"""Read and parse Excel files into structured data."""
|
||||
all_data = {}
|
||||
summaries = []
|
||||
markdown_parts = []
|
||||
|
||||
for file_ref in (self._param.input_files or []):
|
||||
if self.check_if_canceled("ExcelProcessor reading"):
|
||||
return
|
||||
|
||||
# Get variable value
|
||||
value = self._canvas.get_variable_value(file_ref)
|
||||
self.set_input_value(file_ref, str(value)[:200] if value else "")
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
# Handle file content
|
||||
content, filename = self._get_file_content(file_ref)
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
# Parse Excel
|
||||
dfs = self._parse_excel_to_dataframes(content, filename)
|
||||
|
||||
for sheet_name, df in dfs.items():
|
||||
key = f"{filename}_{sheet_name}" if len(dfs) > 1 else filename
|
||||
all_data[key] = df.to_dict(orient="records")
|
||||
|
||||
# Build summary
|
||||
summaries.append(f"**{key}**: {len(df)} rows, {len(df.columns)} columns ({', '.join(df.columns.tolist()[:5])}{'...' if len(df.columns) > 5 else ''})")
|
||||
|
||||
# Build markdown table
|
||||
markdown_parts.append(f"### {key}\n\n{df.head(10).to_markdown(index=False)}\n")
|
||||
|
||||
# Set outputs
|
||||
self.set_output("data", all_data)
|
||||
self.set_output("summary", "\n".join(summaries) if summaries else "No Excel files found")
|
||||
self.set_output("markdown", "\n\n".join(markdown_parts) if markdown_parts else "No data")
|
||||
|
||||
def _merge_excels(self):
|
||||
"""Merge multiple Excel files/sheets into one."""
|
||||
all_dfs = []
|
||||
|
||||
for file_ref in (self._param.input_files or []):
|
||||
if self.check_if_canceled("ExcelProcessor merging"):
|
||||
return
|
||||
|
||||
value = self._canvas.get_variable_value(file_ref)
|
||||
self.set_input_value(file_ref, str(value)[:200] if value else "")
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
content, filename = self._get_file_content(file_ref)
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
dfs = self._parse_excel_to_dataframes(content, filename)
|
||||
all_dfs.extend(dfs.values())
|
||||
|
||||
if not all_dfs:
|
||||
self.set_output("data", {})
|
||||
self.set_output("summary", "No data to merge")
|
||||
return
|
||||
|
||||
# Merge strategy
|
||||
if self._param.merge_strategy == "concat":
|
||||
merged_df = pd.concat(all_dfs, ignore_index=True)
|
||||
elif self._param.merge_strategy == "join" and self._param.join_on:
|
||||
# Join on specified column
|
||||
merged_df = all_dfs[0]
|
||||
for df in all_dfs[1:]:
|
||||
merged_df = merged_df.merge(df, on=self._param.join_on, how="outer")
|
||||
else:
|
||||
merged_df = pd.concat(all_dfs, ignore_index=True)
|
||||
|
||||
self.set_output("data", {"merged": merged_df.to_dict(orient="records")})
|
||||
self.set_output("summary", f"Merged {len(all_dfs)} sources into {len(merged_df)} rows, {len(merged_df.columns)} columns")
|
||||
self.set_output("markdown", merged_df.head(20).to_markdown(index=False))
|
||||
|
||||
def _transform_data(self):
|
||||
"""Apply transformations to data based on instructions or input data."""
|
||||
# Get the data to transform
|
||||
transform_ref = self._param.transform_data
|
||||
if not transform_ref:
|
||||
self.set_output("summary", "No transform data reference provided")
|
||||
return
|
||||
|
||||
data = self._canvas.get_variable_value(transform_ref)
|
||||
self.set_input_value(transform_ref, str(data)[:300] if data else "")
|
||||
|
||||
if data is None:
|
||||
self.set_output("summary", "Transform data is empty")
|
||||
return
|
||||
|
||||
# Convert to DataFrame
|
||||
if isinstance(data, dict):
|
||||
# Could be {"sheet": [rows]} format
|
||||
if all(isinstance(v, list) for v in data.values()):
|
||||
# Multiple sheets
|
||||
all_markdown = []
|
||||
for sheet_name, rows in data.items():
|
||||
df = pd.DataFrame(rows)
|
||||
all_markdown.append(f"### {sheet_name}\n\n{df.to_markdown(index=False)}")
|
||||
self.set_output("data", data)
|
||||
self.set_output("markdown", "\n\n".join(all_markdown))
|
||||
else:
|
||||
df = pd.DataFrame([data])
|
||||
self.set_output("data", df.to_dict(orient="records"))
|
||||
self.set_output("markdown", df.to_markdown(index=False))
|
||||
elif isinstance(data, list):
|
||||
df = pd.DataFrame(data)
|
||||
self.set_output("data", df.to_dict(orient="records"))
|
||||
self.set_output("markdown", df.to_markdown(index=False))
|
||||
else:
|
||||
self.set_output("data", {"raw": str(data)})
|
||||
self.set_output("markdown", str(data))
|
||||
|
||||
self.set_output("summary", "Transformed data ready for processing")
|
||||
|
||||
def _output_excel(self):
|
||||
"""Generate Excel file output from data."""
|
||||
# Get data from transform_data reference
|
||||
transform_ref = self._param.transform_data
|
||||
if not transform_ref:
|
||||
self.set_output("summary", "No data reference for output")
|
||||
return
|
||||
|
||||
data = self._canvas.get_variable_value(transform_ref)
|
||||
self.set_input_value(transform_ref, str(data)[:300] if data else "")
|
||||
|
||||
if data is None:
|
||||
self.set_output("summary", "No data to output")
|
||||
return
|
||||
|
||||
try:
|
||||
# Prepare DataFrames
|
||||
if isinstance(data, dict):
|
||||
if all(isinstance(v, list) for v in data.values()):
|
||||
# Multi-sheet format
|
||||
dfs = {k: pd.DataFrame(v) for k, v in data.items()}
|
||||
else:
|
||||
dfs = {"Sheet1": pd.DataFrame([data])}
|
||||
elif isinstance(data, list):
|
||||
dfs = {"Sheet1": pd.DataFrame(data)}
|
||||
else:
|
||||
self.set_output("summary", "Invalid data format for Excel output")
|
||||
return
|
||||
|
||||
# Generate output
|
||||
doc_id = get_uuid()
|
||||
|
||||
if self._param.output_format == "csv":
|
||||
# For CSV, only output first sheet
|
||||
first_df = list(dfs.values())[0]
|
||||
binary_content = first_df.to_csv(index=False).encode("utf-8")
|
||||
filename = f"{self._param.output_filename}.csv"
|
||||
else:
|
||||
# Excel output
|
||||
excel_io = BytesIO()
|
||||
with pd.ExcelWriter(excel_io, engine='openpyxl') as writer:
|
||||
for sheet_name, df in dfs.items():
|
||||
# Sanitize sheet name (max 31 chars, no special chars)
|
||||
safe_name = sheet_name[:31].replace("/", "_").replace("\\", "_")
|
||||
df.to_excel(writer, sheet_name=safe_name, index=False)
|
||||
excel_io.seek(0)
|
||||
binary_content = excel_io.read()
|
||||
filename = f"{self._param.output_filename}.xlsx"
|
||||
|
||||
# Store file
|
||||
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
|
||||
|
||||
# Set attachment output
|
||||
self.set_output("attachment", {
|
||||
"doc_id": doc_id,
|
||||
"format": self._param.output_format,
|
||||
"file_name": filename
|
||||
})
|
||||
|
||||
total_rows = sum(len(df) for df in dfs.values())
|
||||
self.set_output("summary", f"Generated {filename} with {len(dfs)} sheet(s), {total_rows} total rows")
|
||||
self.set_output("data", {k: v.to_dict(orient="records") for k, v in dfs.items()})
|
||||
|
||||
logging.info(f"ExcelProcessor: Generated {filename} as {doc_id}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"ExcelProcessor output error: {e}")
|
||||
self.set_output("summary", f"Error generating output: {str(e)}")
|
||||
|
||||
def thoughts(self) -> str:
|
||||
"""Return component thoughts for UI display."""
|
||||
op = self._param.operation
|
||||
if op == "read":
|
||||
return "Reading Excel files..."
|
||||
elif op == "merge":
|
||||
return "Merging Excel data..."
|
||||
elif op == "transform":
|
||||
return "Transforming data..."
|
||||
elif op == "output":
|
||||
return "Generating Excel output..."
|
||||
return "Processing Excel..."
|
||||
@ -13,26 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class WebhookParam(ComponentParamBase):
|
||||
|
||||
"""
|
||||
Define the Begin component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return getattr(self, "inputs")
|
||||
class ExitLoopParam(ComponentParamBase, ABC):
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
|
||||
class Webhook(ComponentBase):
|
||||
component_name = "Webhook"
|
||||
class ExitLoop(ComponentBase, ABC):
|
||||
component_name = "ExitLoop"
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
pass
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
return ""
|
||||
@ -18,6 +18,7 @@ import re
|
||||
from functools import partial
|
||||
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
class UserFillUpParam(ComponentParamBase):
|
||||
@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
|
||||
for k, v in kwargs.get("inputs", {}).items():
|
||||
if self.check_if_canceled("UserFillUp processing"):
|
||||
return
|
||||
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
v = FileService.get_files([v["value"]])
|
||||
else:
|
||||
v = v.get("value")
|
||||
self.set_output(k, v)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
|
||||
@ -32,6 +32,7 @@ class IterationParam(ComponentParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.items_ref = ""
|
||||
self.variable={}
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
|
||||
168
agent/component/list_operations.py
Normal file
168
agent/component/list_operations.py
Normal file
@ -0,0 +1,168 @@
|
||||
from abc import ABC
|
||||
import os
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
class ListOperationsParam(ComponentParamBase):
|
||||
"""
|
||||
Define the List Operations component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.query = ""
|
||||
self.operations = "topN"
|
||||
self.n=0
|
||||
self.sort_method = "asc"
|
||||
self.filter = {
|
||||
"operator": "=",
|
||||
"value": ""
|
||||
}
|
||||
self.outputs = {
|
||||
"result": {
|
||||
"value": [],
|
||||
"type": "Array of ?"
|
||||
},
|
||||
"first": {
|
||||
"value": "",
|
||||
"type": "?"
|
||||
},
|
||||
"last": {
|
||||
"value": "",
|
||||
"type": "?"
|
||||
}
|
||||
}
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.query, "query")
|
||||
self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {}
|
||||
|
||||
|
||||
class ListOperations(ComponentBase,ABC):
|
||||
component_name = "ListOperations"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
self.input_objects=[]
|
||||
inputs = getattr(self._param, "query", None)
|
||||
self.inputs = self._canvas.get_variable_value(inputs)
|
||||
if not isinstance(self.inputs, list):
|
||||
raise TypeError("The input of List Operations should be an array.")
|
||||
self.set_input_value(inputs, self.inputs)
|
||||
if self._param.operations == "topN":
|
||||
self._topN()
|
||||
elif self._param.operations == "head":
|
||||
self._head()
|
||||
elif self._param.operations == "tail":
|
||||
self._tail()
|
||||
elif self._param.operations == "filter":
|
||||
self._filter()
|
||||
elif self._param.operations == "sort":
|
||||
self._sort()
|
||||
elif self._param.operations == "drop_duplicates":
|
||||
self._drop_duplicates()
|
||||
|
||||
|
||||
def _coerce_n(self):
|
||||
try:
|
||||
return int(getattr(self._param, "n", 0))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _set_outputs(self, outputs):
|
||||
self._param.outputs["result"]["value"] = outputs
|
||||
self._param.outputs["first"]["value"] = outputs[0] if outputs else None
|
||||
self._param.outputs["last"]["value"] = outputs[-1] if outputs else None
|
||||
|
||||
def _topN(self):
|
||||
n = self._coerce_n()
|
||||
if n < 1:
|
||||
outputs = []
|
||||
else:
|
||||
n = min(n, len(self.inputs))
|
||||
outputs = self.inputs[:n]
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _head(self):
|
||||
n = self._coerce_n()
|
||||
if 1 <= n <= len(self.inputs):
|
||||
outputs = [self.inputs[n - 1]]
|
||||
else:
|
||||
outputs = []
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _tail(self):
|
||||
n = self._coerce_n()
|
||||
if 1 <= n <= len(self.inputs):
|
||||
outputs = [self.inputs[-n]]
|
||||
else:
|
||||
outputs = []
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _filter(self):
|
||||
self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])])
|
||||
|
||||
def _norm(self,v):
|
||||
s = "" if v is None else str(v)
|
||||
return s
|
||||
|
||||
def _eval(self, v, operator, value):
|
||||
if operator == "=":
|
||||
return v == value
|
||||
elif operator == "≠":
|
||||
return v != value
|
||||
elif operator == "contains":
|
||||
return value in v
|
||||
elif operator == "start with":
|
||||
return v.startswith(value)
|
||||
elif operator == "end with":
|
||||
return v.endswith(value)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _sort(self):
|
||||
items = self.inputs or []
|
||||
method = getattr(self._param, "sort_method", "asc") or "asc"
|
||||
reverse = method == "desc"
|
||||
|
||||
if not items:
|
||||
self._set_outputs([])
|
||||
return
|
||||
|
||||
first = items[0]
|
||||
|
||||
if isinstance(first, dict):
|
||||
outputs = sorted(
|
||||
items,
|
||||
key=lambda x: self._hashable(x),
|
||||
reverse=reverse,
|
||||
)
|
||||
else:
|
||||
outputs = sorted(items, reverse=reverse)
|
||||
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _drop_duplicates(self):
|
||||
seen = set()
|
||||
outs = []
|
||||
for item in self.inputs:
|
||||
k = self._hashable(item)
|
||||
if k in seen:
|
||||
continue
|
||||
seen.add(k)
|
||||
outs.append(item)
|
||||
self._set_outputs(outs)
|
||||
|
||||
def _hashable(self,x):
|
||||
if isinstance(x, dict):
|
||||
return tuple(sorted((k, self._hashable(v)) for k, v in x.items()))
|
||||
if isinstance(x, (list, tuple)):
|
||||
return tuple(self._hashable(v) for v in x)
|
||||
if isinstance(x, set):
|
||||
return tuple(sorted(self._hashable(v) for v in x))
|
||||
return x
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "ListOperation in progress"
|
||||
@ -13,12 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generator
|
||||
from typing import Any, AsyncGenerator
|
||||
import json_repair
|
||||
from functools import partial
|
||||
from common.constants import LLMType
|
||||
@ -166,25 +167,67 @@ class LLM(ComponentBase):
|
||||
sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
|
||||
return pts, sys_prompt
|
||||
|
||||
def _generate(self, msg:list[dict], **kwargs) -> str:
|
||||
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
|
||||
if not self.imgs:
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
|
||||
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
|
||||
ans = ""
|
||||
async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
|
||||
async def delta_wrapper(txt_iter):
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal ans, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
ans = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(ans)
|
||||
if ans.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
async for t in txt_iter:
|
||||
yield delta(t)
|
||||
|
||||
if not self.imgs:
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
|
||||
yield t
|
||||
return
|
||||
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
|
||||
yield t
|
||||
|
||||
async def _stream_output_async(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal ans, last_idx, endswith_think
|
||||
nonlocal answer, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
ans = txt
|
||||
answer = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx+delta_ans.find("<think>")]
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
@ -193,20 +236,33 @@ class LLM(ComponentBase):
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(ans)
|
||||
if ans.endswith("</think>"):
|
||||
last_idx = len(answer)
|
||||
if answer.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
if not self.imgs:
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs):
|
||||
yield delta(txt)
|
||||
else:
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||
yield delta(txt)
|
||||
stream_kwargs = {"images": self.imgs} if self.imgs else {}
|
||||
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
|
||||
if isinstance(ans, int):
|
||||
continue
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
yield delta(ans)
|
||||
|
||||
self.set_output("content", answer)
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
@ -217,22 +273,25 @@ class LLM(ComponentBase):
|
||||
|
||||
prompt, msg, _ = self._prepare_prompt_variables()
|
||||
error: str = ""
|
||||
output_structure=None
|
||||
output_structure = None
|
||||
try:
|
||||
output_structure = self._param.outputs['structured']
|
||||
output_structure = self._param.outputs["structured"]
|
||||
except Exception:
|
||||
pass
|
||||
if output_structure:
|
||||
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt += structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
|
||||
schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt_with_schema = prompt + structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
_, msg_fit = message_fit_in(
|
||||
[{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
|
||||
int(self.chat_mdl.max_length * 0.97),
|
||||
)
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
msg.pop(0)
|
||||
ans = await self._generate_async(msg_fit)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
error = ans
|
||||
@ -241,7 +300,7 @@ class LLM(ComponentBase):
|
||||
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
|
||||
return
|
||||
except Exception:
|
||||
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||
msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||
error = "The answer can't not be parsed as JSON"
|
||||
if error:
|
||||
self.set_output("_ERROR", error)
|
||||
@ -249,18 +308,23 @@ class LLM(ComponentBase):
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self._stream_output, prompt, msg))
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
|
||||
ex and ex["goto"]
|
||||
):
|
||||
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
|
||||
return
|
||||
|
||||
for _ in range(self._param.max_retries+1):
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
_, msg_fit = message_fit_in(
|
||||
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
|
||||
)
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
msg.pop(0)
|
||||
ans = await self._generate_async(msg_fit)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
error = ans
|
||||
@ -274,26 +338,12 @@ class LLM(ComponentBase):
|
||||
else:
|
||||
self.set_output("_ERROR", error)
|
||||
|
||||
def _stream_output(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
for ans in self._generate_streamly(msg):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
yield ans
|
||||
answer += ans
|
||||
self.set_output("content", answer)
|
||||
|
||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||
async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||
summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||
logging.info(f"[MEMORY]: {summ}")
|
||||
self._canvas.add_memory(user, assist, summ)
|
||||
|
||||
|
||||
80
agent/component/loop.py
Normal file
80
agent/component/loop.py
Normal file
@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class LoopParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Loop component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loop_variables = []
|
||||
self.loop_termination_condition=[]
|
||||
self.maximum_loop_count = 0
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"items": {
|
||||
"type": "json",
|
||||
"name": "Items"
|
||||
}
|
||||
}
|
||||
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
|
||||
class Loop(ComponentBase, ABC):
|
||||
component_name = "Loop"
|
||||
|
||||
def get_start(self):
|
||||
for cid in self._canvas.components.keys():
|
||||
if self._canvas.get_component(cid)["obj"].component_name.lower() != "loopitem":
|
||||
continue
|
||||
if self._canvas.get_component(cid)["parent_id"] == self._id:
|
||||
return cid
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Loop processing"):
|
||||
return
|
||||
|
||||
for item in self._param.loop_variables:
|
||||
if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]):
|
||||
assert "Loop Variable is not complete."
|
||||
if item["input_mode"]=="variable":
|
||||
self.set_output(item["variable"],self._canvas.get_variable_value(item["value"]))
|
||||
elif item["input_mode"]=="constant":
|
||||
self.set_output(item["variable"],item["value"])
|
||||
else:
|
||||
if item["type"] == "number":
|
||||
self.set_output(item["variable"], 0)
|
||||
elif item["type"] == "string":
|
||||
self.set_output(item["variable"], "")
|
||||
elif item["type"] == "boolean":
|
||||
self.set_output(item["variable"], False)
|
||||
elif item["type"].startswith("object"):
|
||||
self.set_output(item["variable"], {})
|
||||
elif item["type"].startswith("array"):
|
||||
self.set_output(item["variable"], [])
|
||||
else:
|
||||
self.set_output(item["variable"], "")
|
||||
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Loop from canvas."
|
||||
163
agent/component/loopitem.py
Normal file
163
agent/component/loopitem.py
Normal file
@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class LoopItemParam(ComponentParamBase):
|
||||
"""
|
||||
Define the LoopItem component parameters.
|
||||
"""
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
class LoopItem(ComponentBase, ABC):
|
||||
component_name = "LoopItem"
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
super().__init__(canvas, id, param)
|
||||
self._idx = 0
|
||||
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("LoopItem processing"):
|
||||
return
|
||||
parent = self.get_parent()
|
||||
maximum_loop_count = parent._param.maximum_loop_count
|
||||
if self._idx >= maximum_loop_count:
|
||||
self._idx = -1
|
||||
return
|
||||
if self._idx > 0:
|
||||
if self.check_if_canceled("LoopItem processing"):
|
||||
return
|
||||
self._idx += 1
|
||||
|
||||
def evaluate_condition(self,var, operator, value):
|
||||
if isinstance(var, str):
|
||||
if operator == "contains":
|
||||
return value in var
|
||||
elif operator == "not contains":
|
||||
return value not in var
|
||||
elif operator == "start with":
|
||||
return var.startswith(value)
|
||||
elif operator == "end with":
|
||||
return var.endswith(value)
|
||||
elif operator == "is":
|
||||
return var == value
|
||||
elif operator == "is not":
|
||||
return var != value
|
||||
elif operator == "empty":
|
||||
return var == ""
|
||||
elif operator == "not empty":
|
||||
return var != ""
|
||||
|
||||
elif isinstance(var, (int, float)):
|
||||
if operator == "=":
|
||||
return var == value
|
||||
elif operator == "≠":
|
||||
return var != value
|
||||
elif operator == ">":
|
||||
return var > value
|
||||
elif operator == "<":
|
||||
return var < value
|
||||
elif operator == "≥":
|
||||
return var >= value
|
||||
elif operator == "≤":
|
||||
return var <= value
|
||||
elif operator == "empty":
|
||||
return var is None
|
||||
elif operator == "not empty":
|
||||
return var is not None
|
||||
|
||||
elif isinstance(var, bool):
|
||||
if operator == "is":
|
||||
return var is value
|
||||
elif operator == "is not":
|
||||
return var is not value
|
||||
elif operator == "empty":
|
||||
return var is None
|
||||
elif operator == "not empty":
|
||||
return var is not None
|
||||
|
||||
elif isinstance(var, dict):
|
||||
if operator == "empty":
|
||||
return len(var) == 0
|
||||
elif operator == "not empty":
|
||||
return len(var) > 0
|
||||
|
||||
elif isinstance(var, list):
|
||||
if operator == "contains":
|
||||
return value in var
|
||||
elif operator == "not contains":
|
||||
return value not in var
|
||||
|
||||
elif operator == "is":
|
||||
return var == value
|
||||
elif operator == "is not":
|
||||
return var != value
|
||||
|
||||
elif operator == "empty":
|
||||
return len(var) == 0
|
||||
elif operator == "not empty":
|
||||
return len(var) > 0
|
||||
|
||||
raise Exception(f"Invalid operator: {operator}")
|
||||
|
||||
def end(self):
|
||||
if self._idx == -1:
|
||||
return True
|
||||
parent = self.get_parent()
|
||||
logical_operator = parent._param.logical_operator if hasattr(parent._param, "logical_operator") else "and"
|
||||
conditions = []
|
||||
for item in parent._param.loop_termination_condition:
|
||||
if not item.get("variable") or not item.get("operator"):
|
||||
raise ValueError("Loop condition is incomplete.")
|
||||
var = self._canvas.get_variable_value(item["variable"])
|
||||
operator = item["operator"]
|
||||
input_mode = item.get("input_mode", "constant")
|
||||
|
||||
if input_mode == "variable":
|
||||
value = self._canvas.get_variable_value(item.get("value", ""))
|
||||
elif input_mode == "constant":
|
||||
value = item.get("value", "")
|
||||
else:
|
||||
raise ValueError("Invalid input mode.")
|
||||
conditions.append(self.evaluate_condition(var, operator, value))
|
||||
should_end = (
|
||||
all(conditions) if logical_operator == "and"
|
||||
else any(conditions) if logical_operator == "or"
|
||||
else None
|
||||
)
|
||||
if should_end is None:
|
||||
raise ValueError("Invalid logical operator,should be 'and' or 'or'.")
|
||||
|
||||
if should_end:
|
||||
self._idx = -1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def next(self):
|
||||
if self._idx == -1:
|
||||
self._idx = 0
|
||||
else:
|
||||
self._idx += 1
|
||||
if self._idx >= len(self._items):
|
||||
self._idx = -1
|
||||
return False
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Next turn..."
|
||||
@ -13,10 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import logging
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
@ -24,6 +30,8 @@ from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from jinja2 import Template as Jinja2Template
|
||||
|
||||
from common.connection_utils import timeout
|
||||
from common.misc_utils import get_uuid
|
||||
from common import settings
|
||||
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
@ -34,6 +42,8 @@ class MessageParam(ComponentParamBase):
|
||||
super().__init__()
|
||||
self.content = []
|
||||
self.stream = True
|
||||
self.output_format = None # default output format
|
||||
self.auto_play = False
|
||||
self.outputs = {
|
||||
"content": {
|
||||
"type": "str"
|
||||
@ -61,8 +71,12 @@ class Message(ComponentBase):
|
||||
v = ""
|
||||
ans = ""
|
||||
if isinstance(v, partial):
|
||||
for t in v():
|
||||
ans += t
|
||||
iter_obj = v()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
ans = asyncio.run(self._consume_async_gen(iter_obj))
|
||||
else:
|
||||
for t in iter_obj:
|
||||
ans += t
|
||||
elif isinstance(v, list) and delimiter:
|
||||
ans = delimiter.join([str(vv) for vv in v])
|
||||
elif not isinstance(v, str):
|
||||
@ -84,7 +98,13 @@ class Message(ComponentBase):
|
||||
_kwargs[_n] = v
|
||||
return script, _kwargs
|
||||
|
||||
def _stream(self, rand_cnt:str):
|
||||
async def _consume_async_gen(self, agen):
|
||||
buf = ""
|
||||
async for t in agen:
|
||||
buf += t
|
||||
return buf
|
||||
|
||||
async def _stream(self, rand_cnt:str):
|
||||
s = 0
|
||||
all_content = ""
|
||||
cache = {}
|
||||
@ -106,15 +126,27 @@ class Message(ComponentBase):
|
||||
v = ""
|
||||
if isinstance(v, partial):
|
||||
cnt = ""
|
||||
for t in v():
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
iter_obj = v()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
async for t in iter_obj:
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
else:
|
||||
for t in iter_obj:
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
self.set_input_value(exp, cnt)
|
||||
continue
|
||||
elif inspect.isawaitable(v):
|
||||
v = await v
|
||||
elif not isinstance(v, str):
|
||||
try:
|
||||
v = json.dumps(v, ensure_ascii=False)
|
||||
@ -133,6 +165,7 @@ class Message(ComponentBase):
|
||||
yield rand_cnt[s: ]
|
||||
|
||||
self.set_output("content", all_content)
|
||||
self._convert_content(all_content)
|
||||
|
||||
def _is_jinjia2(self, content:str) -> bool:
|
||||
patt = [
|
||||
@ -164,6 +197,227 @@ class Message(ComponentBase):
|
||||
content = re.sub(n, v, content)
|
||||
|
||||
self.set_output("content", content)
|
||||
self._convert_content(content)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
|
||||
def _parse_markdown_table_lines(self, table_lines: list):
|
||||
"""
|
||||
Parse a list of Markdown table lines into a pandas DataFrame.
|
||||
|
||||
Args:
|
||||
table_lines: List of strings, each representing a row in the Markdown table
|
||||
(excluding separator lines like |---|---|)
|
||||
|
||||
Returns:
|
||||
pandas DataFrame with the table data, or None if parsing fails
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
if not table_lines:
|
||||
return None
|
||||
|
||||
rows = []
|
||||
headers = None
|
||||
|
||||
for line in table_lines:
|
||||
# Split by | and clean up
|
||||
cells = [cell.strip() for cell in line.split('|')]
|
||||
# Remove empty first and last elements from split (caused by leading/trailing |)
|
||||
cells = [c for c in cells if c]
|
||||
|
||||
if headers is None:
|
||||
headers = cells
|
||||
else:
|
||||
rows.append(cells)
|
||||
|
||||
if headers and rows:
|
||||
# Ensure all rows have same number of columns as headers
|
||||
normalized_rows = []
|
||||
for row in rows:
|
||||
while len(row) < len(headers):
|
||||
row.append('')
|
||||
normalized_rows.append(row[:len(headers)])
|
||||
|
||||
return pd.DataFrame(normalized_rows, columns=headers)
|
||||
|
||||
return None
|
||||
|
||||
def _convert_content(self, content):
|
||||
if not self._param.output_format:
|
||||
return
|
||||
|
||||
import pypandoc
|
||||
doc_id = get_uuid()
|
||||
|
||||
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx", "xlsx"}:
|
||||
self._param.output_format = "markdown"
|
||||
|
||||
try:
|
||||
if self._param.output_format in {"markdown", "html"}:
|
||||
if isinstance(content, str):
|
||||
converted = pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
else:
|
||||
converted = pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
|
||||
binary_content = converted.encode("utf-8")
|
||||
|
||||
elif self._param.output_format == "xlsx":
|
||||
import pandas as pd
|
||||
from io import BytesIO
|
||||
|
||||
# Debug: log the content being parsed
|
||||
logging.info(f"XLSX Parser: Content length={len(content) if content else 0}, first 500 chars: {content[:500] if content else 'None'}")
|
||||
|
||||
# Try to parse ALL Markdown tables from the content
|
||||
# Each table will be written to a separate sheet
|
||||
tables = [] # List of (sheet_name, dataframe)
|
||||
|
||||
if isinstance(content, str):
|
||||
lines = content.strip().split('\n')
|
||||
logging.info(f"XLSX Parser: Total lines={len(lines)}, lines starting with '|': {sum(1 for line in lines if line.strip().startswith('|'))}")
|
||||
current_table_lines = []
|
||||
current_table_title = None
|
||||
pending_title = None
|
||||
in_table = False
|
||||
table_count = 0
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
|
||||
# Check for potential table title (lines before a table)
|
||||
# Look for patterns like "Table 1:", "## Table", or markdown headers
|
||||
if not in_table and stripped and not stripped.startswith('|'):
|
||||
# Check if this could be a table title
|
||||
lower_stripped = stripped.lower()
|
||||
if (lower_stripped.startswith('table') or
|
||||
stripped.startswith('#') or
|
||||
':' in stripped):
|
||||
pending_title = stripped.lstrip('#').strip()
|
||||
|
||||
if stripped.startswith('|') and '|' in stripped[1:]:
|
||||
# Check if this is a separator line (|---|---|)
|
||||
cleaned = stripped.replace(' ', '').replace('|', '').replace('-', '').replace(':', '')
|
||||
if cleaned == '':
|
||||
continue # Skip separator line
|
||||
|
||||
if not in_table:
|
||||
# Starting a new table
|
||||
in_table = True
|
||||
current_table_lines = []
|
||||
current_table_title = pending_title
|
||||
pending_title = None
|
||||
|
||||
current_table_lines.append(stripped)
|
||||
|
||||
elif in_table and not stripped.startswith('|'):
|
||||
# End of current table - save it
|
||||
if current_table_lines:
|
||||
df = self._parse_markdown_table_lines(current_table_lines)
|
||||
if df is not None and not df.empty:
|
||||
table_count += 1
|
||||
# Generate sheet name
|
||||
if current_table_title:
|
||||
# Clean and truncate title for sheet name
|
||||
sheet_name = current_table_title[:31]
|
||||
sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '')
|
||||
else:
|
||||
sheet_name = f"Table_{table_count}"
|
||||
tables.append((sheet_name, df))
|
||||
|
||||
# Reset for next table
|
||||
in_table = False
|
||||
current_table_lines = []
|
||||
current_table_title = None
|
||||
|
||||
# Check if this line could be a title for the next table
|
||||
if stripped:
|
||||
lower_stripped = stripped.lower()
|
||||
if (lower_stripped.startswith('table') or
|
||||
stripped.startswith('#') or
|
||||
':' in stripped):
|
||||
pending_title = stripped.lstrip('#').strip()
|
||||
|
||||
# Don't forget the last table if content ends with a table
|
||||
if in_table and current_table_lines:
|
||||
df = self._parse_markdown_table_lines(current_table_lines)
|
||||
if df is not None and not df.empty:
|
||||
table_count += 1
|
||||
if current_table_title:
|
||||
sheet_name = current_table_title[:31]
|
||||
sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '')
|
||||
else:
|
||||
sheet_name = f"Table_{table_count}"
|
||||
tables.append((sheet_name, df))
|
||||
|
||||
# Fallback: if no tables found, create single sheet with content
|
||||
if not tables:
|
||||
df = pd.DataFrame({"Content": [content if content else ""]})
|
||||
tables = [("Data", df)]
|
||||
|
||||
# Write all tables to Excel, each in a separate sheet
|
||||
excel_io = BytesIO()
|
||||
with pd.ExcelWriter(excel_io, engine='openpyxl') as writer:
|
||||
used_names = set()
|
||||
for sheet_name, df in tables:
|
||||
# Ensure unique sheet names
|
||||
original_name = sheet_name
|
||||
counter = 1
|
||||
while sheet_name in used_names:
|
||||
suffix = f"_{counter}"
|
||||
sheet_name = original_name[:31-len(suffix)] + suffix
|
||||
counter += 1
|
||||
used_names.add(sheet_name)
|
||||
df.to_excel(writer, sheet_name=sheet_name, index=False)
|
||||
|
||||
excel_io.seek(0)
|
||||
binary_content = excel_io.read()
|
||||
|
||||
logging.info(f"Generated Excel with {len(tables)} sheet(s): {[t[0] for t in tables]}")
|
||||
|
||||
else: # pdf, docx
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp:
|
||||
tmp_name = tmp.name
|
||||
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
else:
|
||||
pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
|
||||
with open(tmp_name, "rb") as f:
|
||||
binary_content = f.read()
|
||||
|
||||
finally:
|
||||
if os.path.exists(tmp_name):
|
||||
os.remove(tmp_name)
|
||||
|
||||
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
|
||||
self.set_output("attachment", {
|
||||
"doc_id":doc_id,
|
||||
"format":self._param.output_format,
|
||||
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
|
||||
|
||||
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
|
||||
192
agent/component/variable_assigner.py
Normal file
192
agent/component/variable_assigner.py
Normal file
@ -0,0 +1,192 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import os
|
||||
import numbers
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
class VariableAssignerParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Variable Assigner component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.variables=[]
|
||||
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"items": {
|
||||
"type": "json",
|
||||
"name": "Items"
|
||||
}
|
||||
}
|
||||
|
||||
class VariableAssigner(ComponentBase,ABC):
|
||||
component_name = "VariableAssigner"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not isinstance(self._param.variables,list):
|
||||
return
|
||||
else:
|
||||
for item in self._param.variables:
|
||||
if any([not item.get("variable"), not item.get("operator"), not item.get("parameter")]):
|
||||
assert "Variable is not complete."
|
||||
variable=item["variable"]
|
||||
operator=item["operator"]
|
||||
parameter=item["parameter"]
|
||||
variable_value=self._canvas.get_variable_value(variable)
|
||||
new_variable=self._operate(variable_value,operator,parameter)
|
||||
self._canvas.set_variable_value(variable, new_variable)
|
||||
|
||||
def _operate(self,variable,operator,parameter):
|
||||
if operator == "overwrite":
|
||||
return self._overwrite(parameter)
|
||||
elif operator == "clear":
|
||||
return self._clear(variable)
|
||||
elif operator == "set":
|
||||
return self._set(variable,parameter)
|
||||
elif operator == "append":
|
||||
return self._append(variable,parameter)
|
||||
elif operator == "extend":
|
||||
return self._extend(variable,parameter)
|
||||
elif operator == "remove_first":
|
||||
return self._remove_first(variable)
|
||||
elif operator == "remove_last":
|
||||
return self._remove_last(variable)
|
||||
elif operator == "+=":
|
||||
return self._add(variable,parameter)
|
||||
elif operator == "-=":
|
||||
return self._subtract(variable,parameter)
|
||||
elif operator == "*=":
|
||||
return self._multiply(variable,parameter)
|
||||
elif operator == "/=":
|
||||
return self._divide(variable,parameter)
|
||||
else:
|
||||
return
|
||||
|
||||
def _overwrite(self,parameter):
|
||||
return self._canvas.get_variable_value(parameter)
|
||||
|
||||
def _clear(self,variable):
|
||||
if isinstance(variable,list):
|
||||
return []
|
||||
elif isinstance(variable,str):
|
||||
return ""
|
||||
elif isinstance(variable,dict):
|
||||
return {}
|
||||
elif isinstance(variable,int):
|
||||
return 0
|
||||
elif isinstance(variable,float):
|
||||
return 0.0
|
||||
elif isinstance(variable,bool):
|
||||
return False
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set(self,variable,parameter):
|
||||
if variable is None:
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,str):
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,bool):
|
||||
return parameter
|
||||
elif isinstance(variable,int):
|
||||
return parameter
|
||||
elif isinstance(variable,float):
|
||||
return parameter
|
||||
else:
|
||||
return parameter
|
||||
|
||||
def _append(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif len(variable)!=0 and not isinstance(parameter,type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
variable.append(parameter)
|
||||
return variable
|
||||
|
||||
def _extend(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif not isinstance(parameter,list):
|
||||
return "ERROR:PARAMETER_NOT_LIST"
|
||||
elif len(variable)!=0 and len(parameter)!=0 and not isinstance(parameter[0],type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
return variable + parameter
|
||||
|
||||
def _remove_first(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[1:]
|
||||
|
||||
def _remove_last(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[:-1]
|
||||
|
||||
def is_number(self, value):
|
||||
if isinstance(value, bool):
|
||||
return False
|
||||
return isinstance(value, numbers.Number)
|
||||
|
||||
def _add(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable + parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _subtract(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable - parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _multiply(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable * parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _divide(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
if parameter==0:
|
||||
return "ERROR:DIVIDE_BY_ZERO"
|
||||
else:
|
||||
return variable/parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Assign variables from canvas."
|
||||
@ -193,7 +193,7 @@
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": [
|
||||
{
|
||||
"content": "Text Content:\n{Splitter:KindDingosJam@chunks}\n",
|
||||
"content": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
@ -226,7 +226,7 @@
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": [
|
||||
{
|
||||
"content": "Text Content:\n\n{Splitter:KindDingosJam@chunks}\n",
|
||||
"content": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
@ -259,7 +259,7 @@
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": [
|
||||
{
|
||||
"content": "Content: \n\n{Splitter:KindDingosJam@chunks}",
|
||||
"content": "Content: \n\n{Splitter:CuteBusesBet@chunks}",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
@ -485,7 +485,7 @@
|
||||
"outputs": {},
|
||||
"presencePenaltyEnabled": false,
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": "Text Content:\n{Splitter:KindDingosJam@chunks}\n",
|
||||
"prompts": "Text Content:\n{Splitter:NineTiesSin@chunks}\n",
|
||||
"sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nExtract the most important keywords/phrases of a given piece of text content.\n\nRequirements\n- Summarize the text content, and give the top 5 important keywords/phrases.\n- The keywords MUST be in the same language as the given piece of text content.\n- The keywords are delimited by ENGLISH COMMA.\n- Output keywords ONLY.",
|
||||
"temperature": 0.1,
|
||||
"temperatureEnabled": false,
|
||||
@ -522,7 +522,7 @@
|
||||
"outputs": {},
|
||||
"presencePenaltyEnabled": false,
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": "Text Content:\n\n{Splitter:KindDingosJam@chunks}\n",
|
||||
"prompts": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n",
|
||||
"sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nPropose 3 questions about a given piece of text content.\n\nRequirements\n- Understand and summarize the text content, and propose the top 3 important questions.\n- The questions SHOULD NOT have overlapping meanings.\n- The questions SHOULD cover the main content of the text as much as possible.\n- The questions MUST be in the same language as the given piece of text content.\n- One question per line.\n- Output questions ONLY.",
|
||||
"temperature": 0.1,
|
||||
"temperatureEnabled": false,
|
||||
@ -559,7 +559,7 @@
|
||||
"outputs": {},
|
||||
"presencePenaltyEnabled": false,
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": "Content: \n\n{Splitter:KindDingosJam@chunks}",
|
||||
"prompts": "Content: \n\n{Splitter:BlueResultsWink@chunks}",
|
||||
"sys_prompt": "Extract important structured information from the given content. Output ONLY a valid JSON string with no additional text. If no important structured information is found, output an empty JSON object: {}.\n\nImportant structured information may include: names, dates, locations, events, key facts, numerical data, or other extractable entities.",
|
||||
"temperature": 0.1,
|
||||
"temperatureEnabled": false,
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -83,10 +83,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
@ -527,10 +527,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
},
|
||||
"label": "ExeSQL",
|
||||
"name": "ExeSQL"
|
||||
@ -578,7 +578,7 @@
|
||||
{
|
||||
"data": {
|
||||
"form": {
|
||||
"text": "Searches for relevant database creation statements.\n\nIt should label with a knowledgebase to which the schema is dumped in. You could use \" General \" as parsing method, \" 2 \" as chunk size and \" ; \" as delimiter."
|
||||
"text": "Searches for relevant database creation statements.\n\nIt should label with a dataset to which the schema is dumped in. You could use \" General \" as parsing method, \" 2 \" as chunk size and \" ; \" as delimiter."
|
||||
},
|
||||
"label": "Note",
|
||||
"name": "Note Schema"
|
||||
|
||||
@ -75,7 +75,7 @@
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"retrival": {"chunks": [], "doc_aggs": []},
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
|
||||
@ -82,7 +82,7 @@
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"retrival": {"chunks": [], "doc_aggs": []},
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
|
||||
@ -31,7 +31,7 @@
|
||||
"component_name": "LLM",
|
||||
"params": {
|
||||
"llm_id": "deepseek-chat",
|
||||
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
|
||||
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n Above is the knowledge base.",
|
||||
"temperature": 0.2
|
||||
}
|
||||
},
|
||||
@ -51,7 +51,7 @@
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"retrival": {"chunks": [], "doc_aggs": []},
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
|
||||
@ -65,7 +65,7 @@
|
||||
"component_name": "Agent",
|
||||
"params": {
|
||||
"llm_id": "deepseek-chat",
|
||||
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
|
||||
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
|
||||
"temperature": 0.2
|
||||
}
|
||||
},
|
||||
@ -85,7 +85,7 @@
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"retrival": {"chunks": [], "doc_aggs": []},
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
|
||||
@ -25,7 +25,7 @@
|
||||
"component_name": "LLM",
|
||||
"params": {
|
||||
"llm_id": "deepseek-chat",
|
||||
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {tavily:0@formalized_content}\n The above is the knowledge base.",
|
||||
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {tavily:0@formalized_content}\n Above is the knowledge base.",
|
||||
"temperature": 0.2
|
||||
}
|
||||
},
|
||||
@ -45,7 +45,7 @@
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"retrival": {"chunks": [], "doc_aggs": []},
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
|
||||
@ -17,13 +17,13 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from common.misc_utils import hash_str2int
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
from rag.prompts.generator import kb_prompt
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
|
||||
from timeit import default_timer as timer
|
||||
|
||||
|
||||
@ -49,12 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
|
||||
self.callback = callback
|
||||
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
return asyncio.run(self.tool_call_async(name, arguments))
|
||||
|
||||
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||
st = timer()
|
||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
||||
tool_obj = self.tools_map[name]
|
||||
if isinstance(tool_obj, MCPToolCallSession):
|
||||
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
|
||||
else:
|
||||
resp = self.tools_map[name].invoke(**arguments)
|
||||
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||
resp = await tool_obj.invoke_async(**arguments)
|
||||
else:
|
||||
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
|
||||
|
||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||
return resp
|
||||
@ -140,6 +147,33 @@ class ToolBase(ComponentBase):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
async def invoke_async(self, **kwargs):
|
||||
"""
|
||||
Async wrapper for tool invocation.
|
||||
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
|
||||
Mirrors the exception handling of `invoke`.
|
||||
"""
|
||||
if self.check_if_canceled("Tool processing"):
|
||||
return
|
||||
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
fn_async = getattr(self, "_invoke_async", None)
|
||||
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||
res = await fn_async(**kwargs)
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
res = await self._invoke(**kwargs)
|
||||
else:
|
||||
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||
logging.exception(e)
|
||||
res = str(e)
|
||||
self._param.debug_inputs = []
|
||||
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||
chunks = []
|
||||
aggs = []
|
||||
|
||||
@ -13,16 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import ast
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from strenum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from common.connection_utils import timeout
|
||||
from strenum import StrEnum
|
||||
|
||||
from agent.tools.base import ToolBase, ToolMeta, ToolParamBase
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
||||
class Language(StrEnum):
|
||||
@ -62,10 +66,10 @@ class CodeExecParam(ToolParamBase):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
self.meta: ToolMeta = {
|
||||
"name": "execute_code",
|
||||
"description": """
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string.
|
||||
Here's a code example for Python(`main` function MUST be included):
|
||||
def main() -> dict:
|
||||
\"\"\"
|
||||
@ -99,16 +103,12 @@ module.exports = { main };
|
||||
"enum": ["python", "javascript"],
|
||||
"required": True,
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "A piece of code in right format. There MUST be main function.",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
"script": {"type": "string", "description": "A piece of code in right format. There MUST be main function.", "required": True},
|
||||
},
|
||||
}
|
||||
super().__init__()
|
||||
self.lang = Language.PYTHON.value
|
||||
self.script = "def main(arg1: str, arg2: str) -> dict: return {\"result\": arg1 + arg2}"
|
||||
self.script = 'def main(arg1: str, arg2: str) -> dict: return {"result": arg1 + arg2}'
|
||||
self.arguments = {}
|
||||
self.outputs = {"result": {"value": "", "type": "string"}}
|
||||
|
||||
@ -119,17 +119,14 @@ module.exports = { main };
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
res = {}
|
||||
for k, v in self.arguments.items():
|
||||
res[k] = {
|
||||
"type": "line",
|
||||
"name": k
|
||||
}
|
||||
res[k] = {"type": "line", "name": k}
|
||||
return res
|
||||
|
||||
|
||||
class CodeExec(ToolBase, ABC):
|
||||
component_name = "CodeExec"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("CodeExec processing"):
|
||||
return
|
||||
@ -138,17 +135,12 @@ class CodeExec(ToolBase, ABC):
|
||||
script = kwargs.get("script", self._param.script)
|
||||
arguments = {}
|
||||
for k, v in self._param.arguments.items():
|
||||
|
||||
if kwargs.get(k):
|
||||
arguments[k] = kwargs[k]
|
||||
continue
|
||||
arguments[k] = self._canvas.get_variable_value(v) if v else None
|
||||
|
||||
self._execute_code(
|
||||
language=lang,
|
||||
code=script,
|
||||
arguments=arguments
|
||||
)
|
||||
self._execute_code(language=lang, code=script, arguments=arguments)
|
||||
|
||||
def _execute_code(self, language: str, code: str, arguments: dict):
|
||||
import requests
|
||||
@ -169,7 +161,7 @@ class CodeExec(ToolBase, ABC):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return "Task has been canceled"
|
||||
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
|
||||
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
@ -183,35 +175,10 @@ class CodeExec(ToolBase, ABC):
|
||||
if stderr:
|
||||
self.set_output("_ERROR", stderr)
|
||||
return
|
||||
try:
|
||||
rt = eval(body.get("stdout", ""))
|
||||
except Exception:
|
||||
rt = body.get("stdout", "")
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
|
||||
if isinstance(rt, tuple):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[i]
|
||||
elif isinstance(rt, dict):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k not in rt or k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[k]
|
||||
else:
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt
|
||||
raw_stdout = body.get("stdout", "")
|
||||
parsed_stdout = self._deserialize_stdout(raw_stdout)
|
||||
logging.info(f"[CodeExec]: http://{settings.SANDBOX_HOST}:9385/run -> {parsed_stdout}")
|
||||
self._populate_outputs(parsed_stdout, raw_stdout)
|
||||
else:
|
||||
self.set_output("_ERROR", "There is no response from sandbox")
|
||||
|
||||
@ -228,3 +195,149 @@ class CodeExec(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Running a short script to process data."
|
||||
|
||||
def _deserialize_stdout(self, stdout: str):
|
||||
text = str(stdout).strip()
|
||||
if not text:
|
||||
return ""
|
||||
for loader in (json.loads, ast.literal_eval):
|
||||
try:
|
||||
return loader(text)
|
||||
except Exception:
|
||||
continue
|
||||
return text
|
||||
|
||||
def _coerce_output_value(self, value, expected_type: Optional[str]):
|
||||
if expected_type is None:
|
||||
return value
|
||||
|
||||
etype = expected_type.strip().lower()
|
||||
inner_type = None
|
||||
if etype.startswith("array<") and etype.endswith(">"):
|
||||
inner_type = etype[6:-1].strip()
|
||||
etype = "array"
|
||||
|
||||
try:
|
||||
if etype == "string":
|
||||
return "" if value is None else str(value)
|
||||
|
||||
if etype == "number":
|
||||
if value is None or value == "":
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return value
|
||||
return float(value)
|
||||
|
||||
if etype == "boolean":
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
lv = value.lower()
|
||||
if lv in ("true", "1", "yes", "y", "on"):
|
||||
return True
|
||||
if lv in ("false", "0", "no", "n", "off"):
|
||||
return False
|
||||
return bool(value)
|
||||
|
||||
if etype == "array":
|
||||
candidate = value
|
||||
if isinstance(candidate, str):
|
||||
parsed = self._deserialize_stdout(candidate)
|
||||
candidate = parsed
|
||||
if isinstance(candidate, tuple):
|
||||
candidate = list(candidate)
|
||||
if not isinstance(candidate, list):
|
||||
candidate = [] if candidate is None else [candidate]
|
||||
|
||||
if inner_type == "string":
|
||||
return ["" if v is None else str(v) for v in candidate]
|
||||
if inner_type == "number":
|
||||
coerced = []
|
||||
for v in candidate:
|
||||
try:
|
||||
if v is None or v == "":
|
||||
coerced.append(None)
|
||||
elif isinstance(v, (int, float)):
|
||||
coerced.append(v)
|
||||
else:
|
||||
coerced.append(float(v))
|
||||
except Exception:
|
||||
coerced.append(v)
|
||||
return coerced
|
||||
return candidate
|
||||
|
||||
if etype == "object":
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
parsed = self._deserialize_stdout(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return value
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
def _populate_outputs(self, parsed_stdout, raw_stdout: str):
|
||||
outputs_items = list(self._param.outputs.items())
|
||||
logging.info(f"[CodeExec]: outputs schema keys: {[k for k, _ in outputs_items]}")
|
||||
if not outputs_items:
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, dict):
|
||||
for key, meta in outputs_items:
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = self._get_by_path(parsed_stdout, key)
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, (list, tuple)):
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = parsed_stdout[idx] if idx < len(parsed_stdout) else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate list key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
default_val = parsed_stdout if parsed_stdout is not None else raw_stdout
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = default_val if idx == 0 else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate scalar key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
|
||||
def _get_by_path(self, data, path: str):
|
||||
if not path:
|
||||
return None
|
||||
cur = data
|
||||
for part in path.split("."):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
return None
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(part)
|
||||
elif isinstance(cur, list):
|
||||
try:
|
||||
idx = int(part)
|
||||
cur = cur[idx]
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if cur is None:
|
||||
return None
|
||||
logging.info(f"[CodeExec]: resolve path '{path}' -> {cur}")
|
||||
return cur
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
from functools import partial
|
||||
import json
|
||||
import os
|
||||
@ -21,13 +22,13 @@ from abc import ABC
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from common.constants import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
@ -81,7 +82,7 @@ class Retrieval(ToolBase, ABC):
|
||||
component_name = "Retrieval"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
|
||||
@ -130,47 +131,51 @@ class Retrieval(ToolBase, ABC):
|
||||
doc_ids=[]
|
||||
if self._param.meta_data_filter!={}:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if self._param.meta_data_filter.get("method") == "auto":
|
||||
|
||||
def _resolve_manual_filter(flt: dict) -> dict:
|
||||
pat = re.compile(self.variable_ref_patt)
|
||||
s = flt.get("value", "")
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
for m in pat.finditer(s):
|
||||
out_parts.append(s[last:m.start()])
|
||||
key = m.group(1)
|
||||
v = self._canvas.get_variable_value(key)
|
||||
if v is None:
|
||||
rep = ""
|
||||
elif isinstance(v, partial):
|
||||
buf = []
|
||||
for chunk in v():
|
||||
buf.append(chunk)
|
||||
rep = "".join(buf)
|
||||
elif isinstance(v, str):
|
||||
rep = v
|
||||
else:
|
||||
rep = json.dumps(v, ensure_ascii=False)
|
||||
|
||||
out_parts.append(rep)
|
||||
last = m.end()
|
||||
|
||||
out_parts.append(s[last:])
|
||||
flt["value"] = "".join(out_parts)
|
||||
return flt
|
||||
|
||||
chat_mdl = None
|
||||
if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
filters = gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif self._param.meta_data_filter.get("method") == "manual":
|
||||
filters=self._param.meta_data_filter["manual"]
|
||||
for flt in filters:
|
||||
pat = re.compile(self.variable_ref_patt)
|
||||
s = flt["value"]
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
for m in pat.finditer(s):
|
||||
out_parts.append(s[last:m.start()])
|
||||
key = m.group(1)
|
||||
v = self._canvas.get_variable_value(key)
|
||||
if v is None:
|
||||
rep = ""
|
||||
elif isinstance(v, partial):
|
||||
buf = []
|
||||
for chunk in v():
|
||||
buf.append(chunk)
|
||||
rep = "".join(buf)
|
||||
elif isinstance(v, str):
|
||||
rep = v
|
||||
else:
|
||||
rep = json.dumps(v, ensure_ascii=False)
|
||||
|
||||
out_parts.append(rep)
|
||||
last = m.end()
|
||||
|
||||
out_parts.append(s[last:])
|
||||
flt["value"] = "".join(out_parts)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids = await apply_meta_data_filter(
|
||||
self._param.meta_data_filter,
|
||||
metas,
|
||||
query,
|
||||
chat_mdl,
|
||||
doc_ids,
|
||||
_resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None,
|
||||
)
|
||||
|
||||
if self._param.cross_languages:
|
||||
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
|
||||
if kbs:
|
||||
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
||||
@ -198,6 +203,7 @@ class Retrieval(ToolBase, ABC):
|
||||
return
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retriever.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
@ -242,6 +248,10 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
return form_cnt
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
|
||||
@ -75,7 +75,7 @@ class YahooFinance(ToolBase, ABC):
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
return None
|
||||
|
||||
if not kwargs.get("stock_code"):
|
||||
self.set_output("report", "")
|
||||
@ -84,33 +84,33 @@ class YahooFinance(ToolBase, ABC):
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
return None
|
||||
|
||||
yohoo_res = []
|
||||
yahoo_res = []
|
||||
try:
|
||||
msft = yf.Ticker(kwargs["stock_code"])
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
return None
|
||||
|
||||
if self._param.info:
|
||||
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
|
||||
yahoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
|
||||
if self._param.history:
|
||||
yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
|
||||
yahoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
|
||||
if self._param.financials:
|
||||
yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
|
||||
yahoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
|
||||
if self._param.balance_sheet:
|
||||
yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
|
||||
yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
|
||||
yahoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
|
||||
yahoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
|
||||
if self._param.cash_flow_statement:
|
||||
yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
|
||||
yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
|
||||
yahoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
|
||||
yahoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
|
||||
if self._param.news:
|
||||
yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
|
||||
self.set_output("report", "\n\n".join(yohoo_res))
|
||||
yahoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
|
||||
self.set_output("report", "\n\n".join(yahoo_res))
|
||||
return self.output("report")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
return None
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"YahooFinance error: {e}")
|
||||
|
||||
@ -51,7 +51,7 @@ class DeepResearcher:
|
||||
"""Remove Result Tags"""
|
||||
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)
|
||||
|
||||
def _generate_reasoning(self, msg_history):
|
||||
async def _generate_reasoning(self, msg_history):
|
||||
"""Generate reasoning steps"""
|
||||
query_think = ""
|
||||
if msg_history[-1]["role"] != "user":
|
||||
@ -59,13 +59,14 @@ class DeepResearcher:
|
||||
else:
|
||||
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
|
||||
|
||||
for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
|
||||
async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
query_think = ans
|
||||
yield query_think
|
||||
return query_think
|
||||
query_think = ""
|
||||
yield query_think
|
||||
|
||||
def _extract_search_queries(self, query_think, question, step_index):
|
||||
"""Extract search queries from thinking"""
|
||||
@ -143,10 +144,10 @@ class DeepResearcher:
|
||||
if d["doc_id"] not in dids:
|
||||
chunk_info["doc_aggs"].append(d)
|
||||
|
||||
def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
|
||||
async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
|
||||
"""Extract and summarize relevant information"""
|
||||
summary_think = ""
|
||||
for ans in self.chat_mdl.chat_streamly(
|
||||
async for ans in self.chat_mdl.async_chat_streamly(
|
||||
RELEVANT_EXTRACTION_PROMPT.format(
|
||||
prev_reasoning=truncated_prev_reasoning,
|
||||
search_query=search_query,
|
||||
@ -160,10 +161,11 @@ class DeepResearcher:
|
||||
continue
|
||||
summary_think = ans
|
||||
yield summary_think
|
||||
summary_think = ""
|
||||
|
||||
return summary_think
|
||||
yield summary_think
|
||||
|
||||
def thinking(self, chunk_info: dict, question: str):
|
||||
async def thinking(self, chunk_info: dict, question: str):
|
||||
executed_search_queries = []
|
||||
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
|
||||
all_reasoning_steps = []
|
||||
@ -180,7 +182,7 @@ class DeepResearcher:
|
||||
|
||||
# Step 1: Generate reasoning
|
||||
query_think = ""
|
||||
for ans in self._generate_reasoning(msg_history):
|
||||
async for ans in self._generate_reasoning(msg_history):
|
||||
query_think = ans
|
||||
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
@ -223,7 +225,7 @@ class DeepResearcher:
|
||||
# Step 6: Extract relevant information
|
||||
think += "\n\n"
|
||||
summary_think = ""
|
||||
for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
|
||||
async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
|
||||
summary_think = ans
|
||||
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
|
||||
@ -14,5 +14,5 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
# from beartype.claw import beartype_this_package
|
||||
# beartype_this_package()
|
||||
|
||||
@ -13,36 +13,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flask_cors import CORS
|
||||
from quart import Blueprint, Quart, request, g, current_app, session
|
||||
from flasgger import Swagger
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
from quart_cors import cors
|
||||
from common.constants import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.db_models import close_connection, APIToken
|
||||
from api.db.services import UserService
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from quart_auth import Unauthorized
|
||||
from common import settings
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.constants import API_VERSION
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Flask(__name__)
|
||||
smtp_mail_server = Mail()
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
|
||||
# Add this at the beginning of your file to configure Swagger UI
|
||||
swagger_config = {
|
||||
@ -76,32 +74,166 @@ swagger = Swagger(
|
||||
},
|
||||
)
|
||||
|
||||
CORS(app, supports_credentials=True, max_age=2592000)
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
|
||||
# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
|
||||
# Default Quart timeouts are 60 seconds which is too short for many LLM backends
|
||||
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
|
||||
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
|
||||
|
||||
## convince for dev and debug
|
||||
# app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
app.config["SESSION_TYPE"] = "filesystem"
|
||||
app.config["SESSION_TYPE"] = "redis"
|
||||
app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis")
|
||||
app.config["MAX_CONTENT_LENGTH"] = int(
|
||||
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
|
||||
)
|
||||
|
||||
Session(app)
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
app.config['SECRET_KEY'] = settings.SECRET_KEY
|
||||
app.secret_key = settings.SECRET_KEY
|
||||
commands.register_commands(app)
|
||||
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
from collections.abc import Awaitable, Callable
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
def _load_user():
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = request.headers.get("Authorization")
|
||||
g.user = None
|
||||
if not authorization:
|
||||
return
|
||||
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if not user and len(authorization.split()) == 2:
|
||||
objs = APIToken.query(token=authorization.split()[1])
|
||||
if objs:
|
||||
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
|
||||
|
||||
current_user = LocalProxy(_load_user)
|
||||
|
||||
|
||||
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
"""A decorator to restrict route access to authenticated users.
|
||||
|
||||
This should be used to wrap a route handler (or view function) to
|
||||
enforce that only authenticated requests can access it. Note that
|
||||
it is important that this decorator be wrapped by the route
|
||||
decorator and not vice, versa, as below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@app.route('/')
|
||||
@login_required
|
||||
async def index():
|
||||
...
|
||||
|
||||
If the request is not authenticated a
|
||||
`quart.exceptions.Unauthorized` exception will be raised.
|
||||
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if not current_user:# or not session.get("_user_id"):
|
||||
raise Unauthorized()
|
||||
else:
|
||||
return await current_app.ensure_async(func)(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def login_user(user, remember=False, duration=None, force=False, fresh=True):
|
||||
"""
|
||||
Logs a user in. You should pass the actual user object to this. If the
|
||||
user's `is_active` property is ``False``, they will not be logged in
|
||||
unless `force` is ``True``.
|
||||
|
||||
This will return ``True`` if the login attempt succeeds, and ``False`` if
|
||||
it fails (i.e. because the user is inactive).
|
||||
|
||||
:param user: The user object to log in.
|
||||
:type user: object
|
||||
:param remember: Whether to remember the user after their session expires.
|
||||
Defaults to ``False``.
|
||||
:type remember: bool
|
||||
:param duration: The amount of time before the remember cookie expires. If
|
||||
``None`` the value set in the settings is used. Defaults to ``None``.
|
||||
:type duration: :class:`datetime.timedelta`
|
||||
:param force: If the user is inactive, setting this to ``True`` will log
|
||||
them in regardless. Defaults to ``False``.
|
||||
:type force: bool
|
||||
:param fresh: setting this to ``False`` will log in the user with a session
|
||||
marked as not "fresh". Defaults to ``True``.
|
||||
:type fresh: bool
|
||||
"""
|
||||
if not force and not user.is_active:
|
||||
return False
|
||||
|
||||
session["_user_id"] = user.id
|
||||
session["_fresh"] = fresh
|
||||
session["_id"] = get_uuid()
|
||||
return True
|
||||
|
||||
|
||||
def logout_user():
|
||||
"""
|
||||
Logs a user out. (You do not need to pass the actual user.) This will
|
||||
also clean up the remember me cookie if it exists.
|
||||
"""
|
||||
if "_user_id" in session:
|
||||
session.pop("_user_id")
|
||||
|
||||
if "_fresh" in session:
|
||||
session.pop("_fresh")
|
||||
|
||||
if "_id" in session:
|
||||
session.pop("_id")
|
||||
|
||||
COOKIE_NAME = "remember_token"
|
||||
cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
|
||||
if cookie_name in request.cookies:
|
||||
session["_remember"] = "clear"
|
||||
if "_remember_seconds" in session:
|
||||
session.pop("_remember_seconds")
|
||||
|
||||
return True
|
||||
|
||||
def search_pages_path(page_path):
|
||||
app_path_list = [
|
||||
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
|
||||
]
|
||||
api_path_list = [
|
||||
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(api_path_list)
|
||||
return app_path_list
|
||||
@ -138,44 +270,12 @@ pages_dir = [
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
|
||||
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
|
||||
]
|
||||
|
||||
|
||||
@login_manager.request_loader
|
||||
def load_user(web_request):
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = web_request.headers.get("Authorization")
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return user[0]
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exc):
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
logging.exception(f"Request failed: {exception}")
|
||||
close_connection()
|
||||
|
||||
@ -13,46 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from flask import request, Response
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import APIToken, Task, File
|
||||
from api.db.services import duplicate_name
|
||||
from quart import request
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import queue_tasks, TaskService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
generate_confirmation_token
|
||||
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import keyword_extraction
|
||||
from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from agent.canvas import Canvas
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def new_token():
|
||||
req = request.json
|
||||
async def new_token():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
@ -97,8 +71,8 @@ def token_list():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("tokens", "tenant_id")
|
||||
@login_required
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
for token in req["tokens"]:
|
||||
APITokenService.filter_delete(
|
||||
@ -126,770 +100,18 @@ def stats():
|
||||
"to_date",
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
|
||||
"agent" if "canvas_id" in request.args else None)
|
||||
res = {
|
||||
"pv": [(o["dt"], o["pv"]) for o in objs],
|
||||
"uv": [(o["dt"], o["uv"]) for o in objs],
|
||||
"speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
|
||||
"tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
|
||||
"round": [(o["dt"], o["round"]) for o in objs],
|
||||
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
|
||||
}
|
||||
|
||||
res = {"pv": [], "uv": [], "speed": [], "tokens": [], "round": [], "thumb_up": []}
|
||||
|
||||
for obj in objs:
|
||||
dt = obj["dt"]
|
||||
res["pv"].append((dt, obj["pv"]))
|
||||
res["uv"].append((dt, obj["uv"]))
|
||||
res["speed"].append((dt, float(obj["tokens"]) / (float(obj["duration"]) + 0.1))) # +0.1 to avoid division by zero
|
||||
res["tokens"].append((dt, float(obj["tokens"]) / 1000.0)) # convert to thousands
|
||||
res["round"].append((dt, obj["round"]))
|
||||
res["thumb_up"].append((dt, obj["thumb_up"]))
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
|
||||
def set_conversation():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
if objs[0].source == "agent":
|
||||
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent"
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
else:
|
||||
e, dia = DialogService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": dia.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = False
|
||||
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
def rename_field(ans):
|
||||
reference = ans['reference']
|
||||
if not isinstance(reference, dict):
|
||||
return
|
||||
for chunk_i in reference.get('chunks', []):
|
||||
if 'docnm_kwd' in chunk_i:
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
stream = req.get("stream", True)
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "content": ""}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=stream)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
if stream:
|
||||
assert isinstance(answer, partial), "Nothing. Is it over?"
|
||||
|
||||
def sse():
|
||||
nonlocal answer, cvs, conv
|
||||
try:
|
||||
for ans in answer():
|
||||
for k in ans.keys():
|
||||
final_ans[k] = ans[k]
|
||||
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
canvas.history.append(("assistant", final_ans["content"]))
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
fillin_conv(result)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
rename_field(result)
|
||||
return get_json_result(data=result)
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **req):
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
answer = ans
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
rename_field(answer)
|
||||
return get_json_result(data=answer)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_conversation(conversation_id):
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
e, conv = API4ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
||||
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for referenct_i in conv['reference']:
|
||||
if referenct_i is None or len(referenct_i) == 0:
|
||||
continue
|
||||
for chunk_i in referenct_i['chunks']:
|
||||
if 'docnm_kwd' in chunk_i.keys():
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/upload', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_name")
|
||||
def upload():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
kb_name = request.form.get("kb_name").strip()
|
||||
tenant_id = objs[0].tenant_id
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
kb_root_folder = FileService.get_kb_folder(tenant_id)
|
||||
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
try:
|
||||
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
|
||||
return get_data_error_result(
|
||||
message="Exceed the maximum file number of a free user!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
name=file.filename,
|
||||
kb_id=kb_id)
|
||||
filetype = filename_type(filename)
|
||||
if not filetype:
|
||||
return get_data_error_result(
|
||||
message="This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
|
||||
location += "_"
|
||||
blob = request.files['file'].read()
|
||||
settings.STORAGE_IMPL.put(kb_id, location, blob)
|
||||
doc = {
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": kb.tenant_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
|
||||
form_data = request.form
|
||||
if "parser_id" in form_data.keys():
|
||||
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
|
||||
doc["parser_id"] = request.form.get("parser_id").strip()
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
doc["parser_id"] = ParserType.EMAIL.value
|
||||
|
||||
doc_result = DocumentService.insert(doc)
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if "run" in form_data.keys():
|
||||
if request.form.get("run").strip() == "1":
|
||||
try:
|
||||
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
|
||||
DocumentService.update_by_id(doc["id"], info)
|
||||
# if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
tenant_id = DocumentService.get_tenant_id(doc["id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
# e, doc = DocumentService.get_by_id(doc["id"])
|
||||
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
||||
e, doc = DocumentService.get_by_id(doc["id"])
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=doc_result.to_json())
|
||||
|
||||
|
||||
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id")
|
||||
def upload_parse():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist('file')
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_chunks():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
|
||||
try:
|
||||
if "doc_name" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
|
||||
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
|
||||
|
||||
elif "doc_id" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
|
||||
doc_id = req['doc_id']
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_name or doc_id"
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
"doc_name": res_item["docnm_kwd"],
|
||||
"image_id": res_item["img_id"]
|
||||
} for res_item in res
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=res)
|
||||
|
||||
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_chunk(chunk_id):
|
||||
from rag.nlp import search
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
if chunk is None:
|
||||
return server_error_response(Exception("Chunk not found"))
|
||||
k = []
|
||||
for n in chunk.keys():
|
||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||
k.append(n)
|
||||
for n in k:
|
||||
del chunk[n]
|
||||
|
||||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_kb_docs():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_name = req.get("kb_name", "").strip()
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
page_number = int(req.get("page", 1))
|
||||
items_per_page = int(req.get("page_size", 15))
|
||||
orderby = req.get("orderby", "create_time")
|
||||
desc = req.get("desc", True)
|
||||
keywords = req.get("keywords", "")
|
||||
status = req.get("status", [])
|
||||
if status:
|
||||
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
|
||||
)
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
|
||||
)
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
|
||||
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/infos', methods=['POST']) # noqa: F821
|
||||
@validate_request("doc_ids")
|
||||
def docinfos():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
doc_ids = req["doc_ids"]
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
return get_json_result(data=list(docs.dicts()))
|
||||
|
||||
|
||||
@manager.route('/document', methods=['DELETE']) # noqa: F821
|
||||
# @login_required
|
||||
def document_rm():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
try:
|
||||
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
|
||||
for doc_id in req.get("doc_ids", []):
|
||||
if doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
|
||||
if not doc_ids:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_names or doc_ids"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
|
||||
errors = ""
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
doc_dic = {}
|
||||
for doc in docs:
|
||||
doc_dic[doc.id] = doc
|
||||
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
if doc_id not in doc_dic:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
doc = doc_dic[doc_id]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
|
||||
settings.STORAGE_IMPL.rm(b, n)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
|
||||
@validate_request("Authorization", "conversation_id", "word")
|
||||
def completion_faq():
|
||||
import base64
|
||||
req = request.json
|
||||
|
||||
token = req["Authorization"]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = True
|
||||
|
||||
msg = [{"role": "user", "content": req["word"]}]
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "doc_aggs": []}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=False)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
ans = ""
|
||||
for a in chat(dia, msg, stream=False, **req):
|
||||
ans = a
|
||||
break
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
kb_ids = req.get("kb_id", [])
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
question = req.get("question")
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("page_size", 30))
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
highlight = bool(req.get("highlight", False))
|
||||
|
||||
try:
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embd_nms) != 1:
|
||||
return get_json_result(
|
||||
data=False, message='Knowledge bases use different embedding models or does not exist."',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
|
||||
rank_feature=label_question(question, kbs))
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
code=RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
from common.http_client import async_request, sync_request
|
||||
from .oauth import OAuthClient, UserInfo
|
||||
|
||||
|
||||
@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):
|
||||
|
||||
def fetch_user_info(self, access_token, **kwargs):
|
||||
"""
|
||||
Fetch GitHub user info.
|
||||
Fetch GitHub user info (synchronous).
|
||||
"""
|
||||
user_info = {}
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
# user info
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info.update(response.json())
|
||||
# email info
|
||||
response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
email_info = response.json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email["primary"]), None
|
||||
)["email"]
|
||||
email_response = sync_request(
|
||||
"GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
|
||||
)
|
||||
email_response.raise_for_status()
|
||||
email_info = email_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch github user info: {e}")
|
||||
|
||||
async def async_fetch_user_info(self, access_token, **kwargs):
|
||||
"""Async variant of fetch_user_info using httpx."""
|
||||
user_info = {}
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url,
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_info.update(response.json())
|
||||
|
||||
email_response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url + "/emails",
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
email_response.raise_for_status()
|
||||
email_info = email_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return self.normalize_user_info(user_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch github user info: {e}")
|
||||
|
||||
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
import urllib.parse
|
||||
from common.http_client import async_request, sync_request
|
||||
|
||||
|
||||
class UserInfo:
|
||||
@ -74,15 +74,40 @@ class OAuthClient:
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code"
|
||||
}
|
||||
response = requests.post(
|
||||
response = sync_request(
|
||||
"POST",
|
||||
self.token_url,
|
||||
data=payload,
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=self.http_request_timeout
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to exchange authorization code for token: {e}")
|
||||
|
||||
async def async_exchange_code_for_token(self, code):
|
||||
"""
|
||||
Async variant of exchange_code_for_token using httpx.
|
||||
"""
|
||||
payload = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
try:
|
||||
response = await async_request(
|
||||
"POST",
|
||||
self.token_url,
|
||||
data=payload,
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to exchange authorization code for token: {e}")
|
||||
|
||||
|
||||
@ -92,11 +117,27 @@ class OAuthClient:
|
||||
"""
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch user info: {e}")
|
||||
|
||||
async def async_fetch_user_info(self, access_token, **kwargs):
|
||||
"""Async variant of fetch_user_info using httpx."""
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url,
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
return self.normalize_user_info(user_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch user info: {e}")
|
||||
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from common.http_client import sync_request
|
||||
from .oauth import OAuthClient
|
||||
|
||||
|
||||
@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
|
||||
"""
|
||||
try:
|
||||
metadata_url = f"{issuer}/.well-known/openid-configuration"
|
||||
response = requests.get(metadata_url, timeout=7)
|
||||
response = sync_request("GET", metadata_url, timeout=7)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch OIDC metadata: {e}")
|
||||
|
||||
|
||||
@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
|
||||
user_info.update(super().fetch_user_info(access_token).to_dict())
|
||||
return self.normalize_user_info(user_info)
|
||||
|
||||
async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
|
||||
user_info = {}
|
||||
if id_token:
|
||||
user_info = self.parse_id_token(id_token)
|
||||
user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
|
||||
return self.normalize_user_info(user_info)
|
||||
|
||||
|
||||
def normalize_user_info(self, user_info):
|
||||
return super().normalize_user_info(user_info)
|
||||
|
||||
@ -13,19 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from quart import request, Response, make_response
|
||||
from agent.component import LLM
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
@ -35,17 +30,18 @@ from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
|
||||
get_request_json
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
import time
|
||||
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@ -57,8 +53,9 @@ def templates():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
for i in req["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -70,8 +67,8 @@ def rm():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
async def save():
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
@ -129,18 +126,18 @@ def getsse(canvas_id):
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
async def run():
|
||||
req = await get_request_json()
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -150,7 +147,7 @@ def run():
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
@ -160,10 +157,10 @@ def run():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
def sse():
|
||||
async def sse():
|
||||
nonlocal canvas, user_id
|
||||
try:
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
@ -179,15 +176,15 @@ def run():
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
resp.call_on_close(lambda: canvas.cancel_task())
|
||||
#resp.call_on_close(lambda: canvas.cancel_task())
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
async def rerun():
|
||||
req = await get_request_json()
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
@ -224,8 +221,8 @@ def cancel(task_id):
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
async def reset():
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -245,76 +242,16 @@ def reset():
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
async def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None
|
||||
}
|
||||
|
||||
if request.args.get("url"):
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
DefaultMarkdownGenerator,
|
||||
PruningContentFilter,
|
||||
CrawlResult
|
||||
)
|
||||
try:
|
||||
url = request.args.get("url")
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter()
|
||||
),
|
||||
pdf=True,
|
||||
screenshot=False
|
||||
)
|
||||
result: CrawlResult = await crawler.arun(
|
||||
url=url,
|
||||
config=crawler_config
|
||||
)
|
||||
return result
|
||||
page = trio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
|
||||
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files['file']
|
||||
files = await request.files
|
||||
file = files['file'] if files and files.get("file") else None
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url")))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -342,8 +279,8 @@ def input_form():
|
||||
@manager.route('/debug', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
async def debug():
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -363,8 +300,13 @@ def debug():
|
||||
for k in outputs.keys():
|
||||
if isinstance(outputs[k], partial):
|
||||
txt = ""
|
||||
for c in outputs[k]():
|
||||
txt += c
|
||||
iter_obj = outputs[k]()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
async for c in iter_obj:
|
||||
txt += c
|
||||
else:
|
||||
for c in iter_obj:
|
||||
txt += c
|
||||
outputs[k] = txt
|
||||
return get_json_result(data=outputs)
|
||||
except Exception as e:
|
||||
@ -374,8 +316,8 @@ def debug():
|
||||
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
def test_db_connect():
|
||||
req = request.json
|
||||
async def test_db_connect():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
@ -406,7 +348,15 @@ def test_db_connect():
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
logging.info(conn_str)
|
||||
redacted_conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD=****;"
|
||||
)
|
||||
logging.info(redacted_conn_str)
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
@ -426,7 +376,6 @@ def test_db_connect():
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception as e:
|
||||
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
|
||||
|
||||
@ -438,7 +387,7 @@ def test_db_connect():
|
||||
|
||||
auth = None
|
||||
if http_scheme == "https" and req.get("password"):
|
||||
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
@ -471,8 +420,8 @@ def test_db_connect():
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=list)
|
||||
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=versions)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
@ -520,8 +469,8 @@ def list_canvas():
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
async def setting():
|
||||
req = await get_request_json()
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@ -602,8 +551,8 @@ def prompts():
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
async def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
return await make_response(blob)
|
||||
|
||||
@ -13,35 +13,37 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
|
||||
import base64
|
||||
import xxhash
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def list_chunk():
|
||||
req = request.json
|
||||
async def list_chunk():
|
||||
req = await get_request_json()
|
||||
doc_id = req["doc_id"]
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
@ -121,8 +123,8 @@ def get():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
def set():
|
||||
req = request.json
|
||||
async def set():
|
||||
req = await get_request_json()
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -146,31 +148,43 @@ def set():
|
||||
d["available_int"] = req["available_int"]
|
||||
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
def _set_sync():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
_d = d
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
_d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
_d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
# update image
|
||||
image_id = req.get("img_id")
|
||||
bkt, name = image_id.split("-")
|
||||
image_base64 = req.get("image_base64", None)
|
||||
if image_base64:
|
||||
image_binary = base64.b64decode(image_base64)
|
||||
settings.STORAGE_IMPL.put(bkt, name, image_binary)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_set_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -178,19 +192,22 @@ def set():
|
||||
@manager.route('/switch', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||
def switch():
|
||||
req = request.json
|
||||
async def switch():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
def _switch_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_switch_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -198,23 +215,26 @@ def switch():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
def _rm_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -222,8 +242,8 @@ def rm():
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
def create():
|
||||
req = request.json
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -244,35 +264,38 @@ def create():
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
def _create_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
|
||||
return await asyncio.to_thread(_create_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -280,8 +303,8 @@ def create():
|
||||
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test():
|
||||
req = request.json
|
||||
async def retrieval_test():
|
||||
req = await get_request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -296,25 +319,29 @@ def retrieval_test():
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
tenant_ids = []
|
||||
user_id = current_user.id
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
async def _retrieval():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
meta_data_filter = {}
|
||||
chat_mdl = None
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
else:
|
||||
meta_data_filter = req.get("meta_data_filter") or {}
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
|
||||
|
||||
tenants = UserTenantService.query(user_id=user_id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
@ -323,15 +350,16 @@ def retrieval_test():
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
data=False, message='Only owner of dataset authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
|
||||
_question = question
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
_question = await cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -341,31 +369,35 @@ def retrieval_test():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
top,
|
||||
doc_ids, rerank_mdl=rerank_mdl,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl,
|
||||
highlight=req.get("highlight", False),
|
||||
rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question,
|
||||
ck = settings.kg_retriever.retrieval(_question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
|
||||
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await _retrieval()
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@ -20,24 +21,25 @@ import uuid
|
||||
from html import escape
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request, make_response
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
|
||||
from api.db import InputType
|
||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
|
||||
from common.constants import RetCode, TaskStatus
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, BOX_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import login_required, current_user
|
||||
from box_sdk_gen import BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_connector():
|
||||
req = request.json
|
||||
async def set_connector():
|
||||
req = await get_request_json()
|
||||
if req.get("id"):
|
||||
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
||||
ConnectorService.update_by_id(req["id"], conn)
|
||||
@ -55,10 +57,9 @@ def set_connector():
|
||||
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
|
||||
"status": TaskStatus.SCHEDULE,
|
||||
}
|
||||
conn["status"] = TaskStatus.SCHEDULE
|
||||
ConnectorService.save(**conn)
|
||||
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
e, conn = ConnectorService.get_by_id(req["id"])
|
||||
|
||||
return get_json_result(data=conn.to_dict())
|
||||
@ -89,8 +90,8 @@ def list_logs(connector_id):
|
||||
|
||||
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def resume(connector_id):
|
||||
req = request.json
|
||||
async def resume(connector_id):
|
||||
req = await get_request_json()
|
||||
if req.get("resume"):
|
||||
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
|
||||
else:
|
||||
@ -101,8 +102,8 @@ def resume(connector_id):
|
||||
@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rebuild(connector_id):
|
||||
req = request.json
|
||||
async def rebuild(connector_id):
|
||||
req = await get_request_json()
|
||||
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
|
||||
@ -117,17 +118,27 @@ def rm_connector(connector_id):
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
GOOGLE_WEB_FLOW_STATE_PREFIX = "google_drive_web_flow_state"
|
||||
GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
|
||||
WEB_FLOW_TTL_SECS = 15 * 60
|
||||
|
||||
|
||||
def _web_state_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}"
|
||||
def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
||||
"""Return Redis key for web OAuth state.
|
||||
|
||||
The default prefix keeps backward compatibility for Google Drive.
|
||||
When source_type == "gmail", a different prefix is used so that
|
||||
Drive/Gmail flows don't clash in Redis.
|
||||
"""
|
||||
prefix = f"{source_type}_web_flow_state"
|
||||
return f"{prefix}:{flow_id}"
|
||||
|
||||
|
||||
def _web_result_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}"
|
||||
def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
||||
"""Return Redis key for web OAuth result.
|
||||
|
||||
Mirrors _web_state_cache_key logic for result storage.
|
||||
"""
|
||||
prefix = f"{source_type}_web_flow_result"
|
||||
return f"{prefix}:{flow_id}"
|
||||
|
||||
|
||||
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
|
||||
@ -146,43 +157,61 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"web": web_section}
|
||||
|
||||
|
||||
def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
|
||||
status = "success" if success else "error"
|
||||
auto_close = "window.close();" if success else ""
|
||||
escaped_message = escape(message)
|
||||
# Drive: ragflow-google-drive-oauth
|
||||
# Gmail: ragflow-gmail-oauth
|
||||
payload_type = f"ragflow-{source}-oauth"
|
||||
payload_json = json.dumps(
|
||||
{
|
||||
"type": "ragflow-google-drive-oauth",
|
||||
"type": payload_type,
|
||||
"status": status,
|
||||
"flowId": flow_id or "",
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
# TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type
|
||||
html = WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
title=f"Google {source.capitalize()} Authorization",
|
||||
heading="Authorization complete" if success else "Authorization failed",
|
||||
message=escaped_message,
|
||||
payload_json=payload_json,
|
||||
auto_close=auto_close,
|
||||
)
|
||||
response = make_response(html, 200)
|
||||
response = await make_response(html, 200)
|
||||
response.headers["Content-Type"] = "text/html; charset=utf-8"
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("credentials")
|
||||
def start_google_drive_web_oauth():
|
||||
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
|
||||
async def start_google_web_oauth():
|
||||
source = request.args.get("type", "google-drive")
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
||||
|
||||
if source == "gmail":
|
||||
redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
scopes = GOOGLE_SCOPES[DocumentSource.GMAIL]
|
||||
else:
|
||||
redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]
|
||||
|
||||
if not redirect_uri:
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Google Drive OAuth redirect URI is not configured on the server.",
|
||||
message="Google OAuth redirect URI is not configured on the server.",
|
||||
)
|
||||
|
||||
req = request.json or {}
|
||||
req = await get_request_json()
|
||||
raw_credentials = req.get("credentials", "")
|
||||
|
||||
try:
|
||||
credentials = _load_credentials(raw_credentials)
|
||||
print(credentials)
|
||||
except ValueError as exc:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
||||
|
||||
@ -199,8 +228,8 @@ def start_google_drive_web_oauth():
|
||||
|
||||
flow_id = str(uuid.uuid4())
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow = Flow.from_client_config(client_config, scopes=scopes)
|
||||
flow.redirect_uri = redirect_uri
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
include_granted_scopes="true",
|
||||
@ -219,7 +248,7 @@ def start_google_drive_web_oauth():
|
||||
"client_config": client_config,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
@ -230,60 +259,115 @@ def start_google_drive_web_oauth():
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
def google_drive_web_oauth_callback():
|
||||
@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def google_gmail_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
source = "gmail"
|
||||
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
||||
if not state_cache:
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
|
||||
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
||||
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL])
|
||||
flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
|
||||
return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
|
||||
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def google_drive_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
source = "google-drive"
|
||||
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
||||
if not state_cache:
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
||||
|
||||
try:
|
||||
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
|
||||
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
||||
|
||||
@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("flow_id")
|
||||
def poll_google_drive_web_result():
|
||||
req = request.json or {}
|
||||
async def poll_google_web_result():
|
||||
req = await request.json or {}
|
||||
source = request.args.get("type")
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
||||
flow_id = req.get("flow_id")
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source))
|
||||
if not cache_raw:
|
||||
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
||||
|
||||
@ -291,5 +375,109 @@ def poll_google_drive_web_result():
|
||||
if result.get("user_id") != current_user.id:
|
||||
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
||||
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id))
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id, source))
|
||||
return get_json_result(data={"credentials": result.get("credentials")})
|
||||
|
||||
@manager.route("/box/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def start_box_web_oauth():
|
||||
req = await get_request_json()
|
||||
|
||||
client_id = req.get("client_id")
|
||||
client_secret = req.get("client_secret")
|
||||
redirect_uri = req.get("redirect_uri", BOX_WEB_OAUTH_REDIRECT_URI)
|
||||
|
||||
if not client_id or not client_secret:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box client_id and client_secret are required.")
|
||||
|
||||
flow_id = str(uuid.uuid4())
|
||||
|
||||
box_auth = BoxOAuth(
|
||||
OAuthConfig(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
)
|
||||
|
||||
auth_url = box_auth.get_authorize_url(
|
||||
options=GetAuthorizeUrlOptions(
|
||||
redirect_uri=redirect_uri,
|
||||
state=flow_id,
|
||||
)
|
||||
)
|
||||
|
||||
cache_payload = {
|
||||
"user_id": current_user.id,
|
||||
"auth_url": auth_url,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, "box"), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
return get_json_result(
|
||||
data = {
|
||||
"flow_id": flow_id,
|
||||
"authorization_url": auth_url,
|
||||
"expires_in": WEB_FLOW_TTL_SECS,}
|
||||
)
|
||||
|
||||
@manager.route("/box/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def box_web_oauth_callback():
|
||||
flow_id = request.args.get("state")
|
||||
if not flow_id:
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth parameters.", "box")
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return await _render_web_oauth_popup(flow_id, False, "Missing authorization code from Box.", "box")
|
||||
|
||||
cache_payload = json.loads(REDIS_CONN.get(_web_state_cache_key(flow_id, "box")))
|
||||
if not cache_payload:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box OAuth session expired or invalid.")
|
||||
|
||||
error = request.args.get("error")
|
||||
error_description = request.args.get("error_description") or error
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(flow_id, "box"))
|
||||
return await _render_web_oauth_popup(flow_id, False, error_description or "Authorization failed.", "box")
|
||||
|
||||
auth = BoxOAuth(
|
||||
OAuthConfig(
|
||||
client_id=cache_payload.get("client_id"),
|
||||
client_secret=cache_payload.get("client_secret"),
|
||||
)
|
||||
)
|
||||
|
||||
auth.get_tokens_authorization_code_grant(code)
|
||||
token = auth.retrieve_token()
|
||||
result_payload = {
|
||||
"user_id": cache_payload.get("user_id"),
|
||||
"client_id": cache_payload.get("client_id"),
|
||||
"client_secret": cache_payload.get("client_secret"),
|
||||
"access_token": token.access_token,
|
||||
"refresh_token": token.refresh_token,
|
||||
}
|
||||
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(flow_id, "box"), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(flow_id, "box"))
|
||||
|
||||
return await _render_web_oauth_popup(flow_id, True, "Authorization completed successfully.", "box")
|
||||
|
||||
@manager.route("/box/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("flow_id")
|
||||
async def poll_box_web_result():
|
||||
req = await get_request_json()
|
||||
flow_id = req.get("flow_id")
|
||||
|
||||
cache_blob = REDIS_CONN.get(_web_result_cache_key(flow_id, "box"))
|
||||
if not cache_blob:
|
||||
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
||||
|
||||
cache_raw = json.loads(cache_blob)
|
||||
if cache_raw.get("user_id") != current_user.id:
|
||||
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
||||
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id, "box"))
|
||||
|
||||
return get_json_result(data={"credentials": cache_raw})
|
||||
@ -14,19 +14,21 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
import tempfile
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
|
||||
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
from common.constants import RetCode, LLMType
|
||||
@ -34,8 +36,8 @@ from common.constants import RetCode, LLMType
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_conversation():
|
||||
req = request.json
|
||||
async def set_conversation():
|
||||
req = await get_request_json()
|
||||
conv_id = req.get("conversation_id")
|
||||
is_new = req.get("is_new")
|
||||
name = req.get("name", "New conversation")
|
||||
@ -78,14 +80,13 @@ def set_conversation():
|
||||
|
||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
async def get():
|
||||
conv_id = request.args["conversation_id"]
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
avatar = None
|
||||
for tenant in tenants:
|
||||
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
|
||||
if dialog and len(dialog) > 0:
|
||||
@ -129,8 +130,9 @@ def getsse(dialog_id):
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm():
|
||||
conv_ids = request.json["conversation_ids"]
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
conv_ids = req["conversation_ids"]
|
||||
try:
|
||||
for cid in conv_ids:
|
||||
exist, conv = ConversationService.get_by_id(cid)
|
||||
@ -150,7 +152,7 @@ def rm():
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_conversation():
|
||||
async def list_conversation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||
@ -166,8 +168,8 @@ def list_conversation():
|
||||
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
async def completion():
|
||||
req = await get_request_json()
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
@ -216,10 +218,10 @@ def completion():
|
||||
dia.llm_setting = chat_model_config
|
||||
|
||||
is_embedded = bool(chat_model_id)
|
||||
def stream():
|
||||
async def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **req):
|
||||
async for ans in async_chat(dia, msg, True, **req):
|
||||
ans = structure_answer(conv, ans, message_id, conv.id)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
if not is_embedded:
|
||||
@ -239,7 +241,7 @@ def completion():
|
||||
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
async for ans in async_chat(dia, msg, **req):
|
||||
answer = structure_answer(conv, ans, message_id, conv.id)
|
||||
if not is_embedded:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
@ -248,11 +250,69 @@ def completion():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def sequence2txt():
|
||||
req = await request.form
|
||||
stream_mode = req.get("stream", "false").lower() == "true"
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_data_error_result(message="Missing 'file' in multipart form-data")
|
||||
|
||||
uploaded = files["file"]
|
||||
|
||||
ALLOWED_EXTS = {
|
||||
".wav", ".mp3", ".m4a", ".aac",
|
||||
".flac", ".ogg", ".webm",
|
||||
".opus", ".wma"
|
||||
}
|
||||
|
||||
filename = uploaded.filename or ""
|
||||
suffix = os.path.splitext(filename)[-1].lower()
|
||||
if suffix not in ALLOWED_EXTS:
|
||||
return get_data_error_result(message=
|
||||
f"Unsupported audio format: {suffix}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
|
||||
)
|
||||
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
|
||||
os.close(fd)
|
||||
await uploaded.save(temp_audio_path)
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
asr_id = tenants[0]["asr_id"]
|
||||
if not asr_id:
|
||||
return get_data_error_result(message="No default ASR model is set")
|
||||
|
||||
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
|
||||
if not stream_mode:
|
||||
text = asr_mdl.transcription(temp_audio_path)
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
return get_json_result(data={"text": text})
|
||||
async def event_stream():
|
||||
try:
|
||||
for evt in asr_mdl.stream_transcription(temp_audio_path):
|
||||
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
|
||||
return Response(event_stream(), content_type="text/event-stream")
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def tts():
|
||||
req = request.json
|
||||
async def tts():
|
||||
req = await get_request_json()
|
||||
text = req["text"]
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
@ -284,8 +344,8 @@ def tts():
|
||||
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def delete_msg():
|
||||
req = request.json
|
||||
async def delete_msg():
|
||||
req = await get_request_json()
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -307,8 +367,8 @@ def delete_msg():
|
||||
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def thumbup():
|
||||
req = request.json
|
||||
async def thumbup():
|
||||
req = await get_request_json()
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -334,8 +394,8 @@ def thumbup():
|
||||
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about():
|
||||
req = request.json
|
||||
async def ask_about():
|
||||
req = await get_request_json()
|
||||
uid = current_user.id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -346,10 +406,10 @@ def ask_about():
|
||||
if search_app:
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
def stream():
|
||||
async def stream():
|
||||
nonlocal req, uid
|
||||
try:
|
||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
@ -366,8 +426,8 @@ def ask_about():
|
||||
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
req = request.json
|
||||
async def mindmap():
|
||||
req = await get_request_json()
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||
@ -375,7 +435,7 @@ def mindmap():
|
||||
kb_ids.extend(req["kb_ids"])
|
||||
kb_ids = list(set(kb_ids))
|
||||
|
||||
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||
mind_map = await gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
@ -384,8 +444,8 @@ def mindmap():
|
||||
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
def related_questions():
|
||||
req = request.json
|
||||
async def related_questions():
|
||||
req = await get_request_json()
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
@ -402,7 +462,7 @@ def related_questions():
|
||||
if "parameter" in gen_conf:
|
||||
del gen_conf["parameter"]
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
|
||||
@ -14,25 +14,24 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from common.constants import StatusEnum
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("prompt_config")
|
||||
@login_required
|
||||
def set_dialog():
|
||||
req = request.json
|
||||
async def set_dialog():
|
||||
req = await get_request_json()
|
||||
dialog_id = req.get("dialog_id", "")
|
||||
is_create = not dialog_id
|
||||
name = req.get("name", "New Dialog")
|
||||
@ -66,7 +65,7 @@ def set_dialog():
|
||||
|
||||
if not is_create:
|
||||
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
|
||||
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no knowledge base / Tavily used here.")
|
||||
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["optional"]:
|
||||
@ -154,33 +153,34 @@ def get_kb_names(kb_ids):
|
||||
@login_required
|
||||
def list_dialogs():
|
||||
try:
|
||||
diags = DialogService.query(
|
||||
conversations = DialogService.query(
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value,
|
||||
reverse=True,
|
||||
order_by=DialogService.model.create_time)
|
||||
diags = [d.to_dict() for d in diags]
|
||||
for d in diags:
|
||||
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
|
||||
return get_json_result(data=diags)
|
||||
conversations = [d.to_dict() for d in conversations]
|
||||
for conversation in conversations:
|
||||
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
|
||||
return get_json_result(data=conversations)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/next', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_dialogs_next():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
async def list_dialogs_next():
|
||||
args = request.args
|
||||
keywords = args.get("keywords", "")
|
||||
page_number = int(args.get("page", 0))
|
||||
items_per_page = int(args.get("page_size", 0))
|
||||
parser_id = args.get("parser_id")
|
||||
orderby = args.get("orderby", "create_time")
|
||||
if args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -207,8 +207,8 @@ def list_dialogs_next():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("dialog_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
dialog_list=[]
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
try:
|
||||
|
||||
@ -13,22 +13,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from quart import request, make_response
|
||||
from api.apps import current_user, login_required
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -39,7 +38,7 @@ from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
validate_request, get_request_json,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -53,14 +52,16 @@ from common import settings
|
||||
@manager.route("/upload", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def upload():
|
||||
kb_id = request.form.get("kb_id")
|
||||
async def upload():
|
||||
form = await request.form
|
||||
kb_id = form.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
if "file" not in request.files:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -69,11 +70,11 @@ def upload():
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
raise LookupError("Can't find this dataset!")
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, files = FileService.upload_document(kb, file_objs, current_user.id)
|
||||
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
|
||||
@ -87,17 +88,18 @@ def upload():
|
||||
@manager.route("/web_crawl", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "url")
|
||||
def web_crawl():
|
||||
kb_id = request.form.get("kb_id")
|
||||
async def web_crawl():
|
||||
form = await request.form
|
||||
kb_id = form.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
name = request.form.get("name")
|
||||
url = request.form.get("url")
|
||||
name = form.get("name")
|
||||
url = form.get("url")
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
raise LookupError("Can't find this dataset!")
|
||||
if check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -152,8 +154,8 @@ def web_crawl():
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "kb_id")
|
||||
def create():
|
||||
req = request.json
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
kb_id = req["kb_id"]
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -167,10 +169,10 @@ def create():
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Can't find this knowledgebase!")
|
||||
return get_data_error_result(message="Can't find this dataset!")
|
||||
|
||||
if DocumentService.query(name=req["name"], kb_id=kb_id):
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
return get_data_error_result(message="Duplicated document name in the same dataset.")
|
||||
|
||||
kb_root_folder = FileService.get_kb_folder(kb.tenant_id)
|
||||
if not kb_root_folder:
|
||||
@ -208,7 +210,7 @@ def create():
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_docs():
|
||||
async def list_docs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -217,7 +219,7 @@ def list_docs():
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
keywords = request.args.get("keywords", "")
|
||||
|
||||
page_number = int(request.args.get("page", 0))
|
||||
@ -230,7 +232,7 @@ def list_docs():
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
run_status = req.get("run_status", [])
|
||||
if run_status:
|
||||
@ -245,9 +247,55 @@ def list_docs():
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = req.get("suffix", [])
|
||||
metadata_condition = req.get("metadata_condition", {}) or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_data_error_result(message="metadata_condition must be an object.")
|
||||
metadata = req.get("metadata", {}) or {}
|
||||
if metadata and not isinstance(metadata, dict):
|
||||
return get_data_error_result(message="metadata must be an object.")
|
||||
|
||||
doc_ids_filter = None
|
||||
metas = None
|
||||
if metadata_condition or metadata:
|
||||
metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
|
||||
|
||||
if metadata_condition:
|
||||
doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
if metadata_condition.get("conditions") and not doc_ids_filter:
|
||||
return get_json_result(data={"total": 0, "docs": []})
|
||||
|
||||
if metadata:
|
||||
metadata_doc_ids = None
|
||||
for key, values in metadata.items():
|
||||
if not values:
|
||||
continue
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
values = [str(v) for v in values if v is not None and str(v).strip()]
|
||||
if not values:
|
||||
continue
|
||||
key_doc_ids = set()
|
||||
for value in values:
|
||||
key_doc_ids.update(metas.get(key, {}).get(value, []))
|
||||
if metadata_doc_ids is None:
|
||||
metadata_doc_ids = key_doc_ids
|
||||
else:
|
||||
metadata_doc_ids &= key_doc_ids
|
||||
if not metadata_doc_ids:
|
||||
return get_json_result(data={"total": 0, "docs": []})
|
||||
if metadata_doc_ids is not None:
|
||||
if doc_ids_filter is None:
|
||||
doc_ids_filter = metadata_doc_ids
|
||||
else:
|
||||
doc_ids_filter &= metadata_doc_ids
|
||||
if not doc_ids_filter:
|
||||
return get_json_result(data={"total": 0, "docs": []})
|
||||
|
||||
if doc_ids_filter is not None:
|
||||
doc_ids_filter = list(doc_ids_filter)
|
||||
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix)
|
||||
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids_filter)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
@ -270,8 +318,8 @@ def list_docs():
|
||||
|
||||
@manager.route("/filter", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def get_filter():
|
||||
req = request.get_json()
|
||||
async def get_filter():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
@ -281,7 +329,7 @@ def get_filter():
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
keywords = req.get("keywords", "")
|
||||
|
||||
@ -308,8 +356,8 @@ def get_filter():
|
||||
|
||||
@manager.route("/infos", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def docinfos():
|
||||
req = request.json
|
||||
async def doc_infos():
|
||||
req = await get_request_json()
|
||||
doc_ids = req["doc_ids"]
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
@ -318,6 +366,87 @@ def docinfos():
|
||||
return get_json_result(data=list(docs.dicts()))
|
||||
|
||||
|
||||
@manager.route("/metadata/summary", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def metadata_summary():
|
||||
req = await get_request_json()
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
try:
|
||||
summary = DocumentService.get_metadata_summary(kb_id)
|
||||
return get_json_result(data={"summary": summary})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/metadata/update", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def metadata_update():
|
||||
req = await get_request_json()
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
selector = req.get("selector", {}) or {}
|
||||
updates = req.get("updates", []) or []
|
||||
deletes = req.get("deletes", []) or []
|
||||
|
||||
if not isinstance(selector, dict):
|
||||
return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR)
|
||||
if not isinstance(updates, list) or not isinstance(deletes, list):
|
||||
return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
metadata_condition = selector.get("metadata_condition", {}) or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
document_ids = selector.get("document_ids", []) or []
|
||||
if document_ids and not isinstance(document_ids, list):
|
||||
return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
for upd in updates:
|
||||
if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
|
||||
return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
|
||||
for d in deletes:
|
||||
if not isinstance(d, dict) or not d.get("key"):
|
||||
return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id])
|
||||
target_doc_ids = set(kb_doc_ids)
|
||||
if document_ids:
|
||||
invalid_ids = set(document_ids) - set(kb_doc_ids)
|
||||
if invalid_ids:
|
||||
return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR)
|
||||
target_doc_ids = set(document_ids)
|
||||
|
||||
if metadata_condition:
|
||||
metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
|
||||
filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
target_doc_ids = target_doc_ids & filtered_ids
|
||||
if metadata_condition.get("conditions") and not target_doc_ids:
|
||||
return get_json_result(data={"updated": 0, "matched_docs": 0})
|
||||
|
||||
target_doc_ids = list(target_doc_ids)
|
||||
updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes)
|
||||
return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
|
||||
|
||||
|
||||
@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def thumbnails():
|
||||
@ -340,8 +469,8 @@ def thumbnails():
|
||||
@manager.route("/change_status", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids", "status")
|
||||
def change_status():
|
||||
req = request.get_json()
|
||||
async def change_status():
|
||||
req = await get_request_json()
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
status = str(req.get("status", ""))
|
||||
|
||||
@ -361,7 +490,7 @@ def change_status():
|
||||
continue
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
result[doc_id] = {"error": "Can't find this knowledgebase!"}
|
||||
result[doc_id] = {"error": "Can't find this dataset!"}
|
||||
continue
|
||||
if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
|
||||
result[doc_id] = {"error": "Database error (Document update)!"}
|
||||
@ -380,8 +509,8 @@ def change_status():
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
doc_ids = req["doc_id"]
|
||||
if isinstance(doc_ids, str):
|
||||
doc_ids = [doc_ids]
|
||||
@ -390,7 +519,7 @@ def rm():
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
errors = FileService.delete_docs(doc_ids, current_user.id)
|
||||
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
@ -401,46 +530,50 @@ def rm():
|
||||
@manager.route("/run", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids", "run")
|
||||
def run():
|
||||
req = request.json
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
async def run():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
kb_table_num_map = {}
|
||||
for id in req["doc_ids"]:
|
||||
info = {"run": str(req["run"]), "progress": 0}
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
def _run_sync():
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
kb_table_num_map = {}
|
||||
for id in req["doc_ids"]:
|
||||
info = {"run": str(req["run"]), "progress": 0}
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
|
||||
if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc_dict = doc.to_dict()
|
||||
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_run_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -448,66 +581,72 @@ def run():
|
||||
@manager.route("/rename", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
async def rename():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||
def _rename_sync():
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||
if d.name == req["name"]:
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||
if d.name == req["name"]:
|
||||
return get_data_error_result(message="Duplicated document name in the same dataset.")
|
||||
|
||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
title_tks = rag_tokenizer.tokenize(req["name"])
|
||||
es_body = {
|
||||
"docnm_kwd": req["name"],
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
title_tks = rag_tokenizer.tokenize(req["name"])
|
||||
es_body = {
|
||||
"docnm_kwd": req["name"],
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rename_sync)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def get(doc_id):
|
||||
async def get(doc_id):
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
response = await make_response(data)
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
if doc.type == FileType.VISUAL.value:
|
||||
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
@ -517,12 +656,27 @@ def get(doc_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def download_attachment(attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def change_parser():
|
||||
async def change_parser():
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -544,6 +698,7 @@ def change_parser():
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
if "pipeline_id" in req and req["pipeline_id"] != "":
|
||||
@ -572,13 +727,14 @@ def change_parser():
|
||||
|
||||
@manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def get_image(image_id):
|
||||
async def get_image(image_id):
|
||||
try:
|
||||
arr = image_id.split("-")
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
except Exception as e:
|
||||
@ -588,24 +744,26 @@ def get_image(image_id):
|
||||
@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id")
|
||||
def upload_and_parse():
|
||||
if "file" not in request.files:
|
||||
async def upload_and_parse():
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
||||
|
||||
form = await request.form
|
||||
doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route("/parse", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def parse():
|
||||
url = request.json.get("url") if request.json else ""
|
||||
async def parse():
|
||||
req = await get_request_json()
|
||||
url = req.get("url", "")
|
||||
if url:
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -646,10 +804,11 @@ def parse():
|
||||
txt = FileService.parse_docs([f], current_user.id)
|
||||
return get_json_result(data=txt)
|
||||
|
||||
if "file" not in request.files:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
txt = FileService.parse_docs(file_objs, current_user.id)
|
||||
|
||||
return get_json_result(data=txt)
|
||||
@ -658,8 +817,8 @@ def parse():
|
||||
@manager.route("/set_meta", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "meta")
|
||||
def set_meta():
|
||||
req = request.json
|
||||
async def set_meta():
|
||||
req = await get_request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
@ -667,7 +826,10 @@ def set_meta():
|
||||
if not isinstance(meta, dict):
|
||||
return get_json_result(data=False, message="Only dictionary type supported.", code=RetCode.ARGUMENT_ERROR)
|
||||
for k, v in meta.items():
|
||||
if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
|
||||
if isinstance(v, list):
|
||||
if not all(isinstance(i, (str, int, float)) for i in v):
|
||||
return get_json_result(data=False, message=f"The type is not supported in list: {v}", code=RetCode.ARGUMENT_ERROR)
|
||||
elif not isinstance(v, (str, int, float)):
|
||||
return get_json_result(data=False, message=f"The type is not supported: {v}", code=RetCode.ARGUMENT_ERROR)
|
||||
except Exception as e:
|
||||
return get_json_result(data=False, message=f"Json syntax error: {e}", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -685,3 +847,13 @@ def set_meta():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload_info", methods=["POST"]) # noqa: F821
|
||||
async def upload_info():
|
||||
files = await request.files
|
||||
file = files['file'] if files and files.get("file") else None
|
||||
try:
|
||||
return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url")))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
479
api/apps/evaluation_app.py
Normal file
479
api/apps/evaluation_app.py
Normal file
@ -0,0 +1,479 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
RAG Evaluation API Endpoints
|
||||
|
||||
Provides REST API for RAG evaluation functionality including:
|
||||
- Dataset management
|
||||
- Test case management
|
||||
- Evaluation execution
|
||||
- Results retrieval
|
||||
- Configuration recommendations
|
||||
"""
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.evaluation_service import EvaluationService
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
get_request_json,
|
||||
server_error_response,
|
||||
validate_request
|
||||
)
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
# ==================== Dataset Management ====================
|
||||
|
||||
@manager.route('/dataset/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "kb_ids")
|
||||
async def create_dataset():
|
||||
"""
|
||||
Create a new evaluation dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"name": "Dataset name",
|
||||
"description": "Optional description",
|
||||
"kb_ids": ["kb_id1", "kb_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
name = req.get("name", "").strip()
|
||||
description = req.get("description", "")
|
||||
kb_ids = req.get("kb_ids", [])
|
||||
|
||||
if not name:
|
||||
return get_data_error_result(message="Dataset name cannot be empty")
|
||||
|
||||
if not kb_ids or not isinstance(kb_ids, list):
|
||||
return get_data_error_result(message="kb_ids must be a non-empty list")
|
||||
|
||||
success, result = EvaluationService.create_dataset(
|
||||
name=name,
|
||||
description=description,
|
||||
kb_ids=kb_ids,
|
||||
tenant_id=current_user.id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"dataset_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def list_datasets():
|
||||
"""
|
||||
List evaluation datasets for current tenant.
|
||||
|
||||
Query params:
|
||||
- page: Page number (default: 1)
|
||||
- page_size: Items per page (default: 20)
|
||||
"""
|
||||
try:
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = int(request.args.get("page_size", 20))
|
||||
|
||||
result = EvaluationService.list_datasets(
|
||||
tenant_id=current_user.id,
|
||||
user_id=current_user.id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_dataset(dataset_id):
|
||||
"""Get dataset details by ID"""
|
||||
try:
|
||||
dataset = EvaluationService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
return get_data_error_result(
|
||||
message="Dataset not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=dataset)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
async def update_dataset(dataset_id):
|
||||
"""
|
||||
Update dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"name": "New name",
|
||||
"description": "New description",
|
||||
"kb_ids": ["kb_id1", "kb_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
|
||||
# Remove fields that shouldn't be updated
|
||||
req.pop("id", None)
|
||||
req.pop("tenant_id", None)
|
||||
req.pop("created_by", None)
|
||||
req.pop("create_time", None)
|
||||
|
||||
success = EvaluationService.update_dataset(dataset_id, **req)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to update dataset")
|
||||
|
||||
return get_json_result(data={"dataset_id": dataset_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_dataset(dataset_id):
|
||||
"""Delete dataset (soft delete)"""
|
||||
try:
|
||||
success = EvaluationService.delete_dataset(dataset_id)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to delete dataset")
|
||||
|
||||
return get_json_result(data={"dataset_id": dataset_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Test Case Management ====================
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
async def add_test_case(dataset_id):
|
||||
"""
|
||||
Add a test case to a dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"question": "Test question",
|
||||
"reference_answer": "Optional ground truth answer",
|
||||
"relevant_doc_ids": ["doc_id1", "doc_id2"],
|
||||
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
question = req.get("question", "").strip()
|
||||
|
||||
if not question:
|
||||
return get_data_error_result(message="Question cannot be empty")
|
||||
|
||||
success, result = EvaluationService.add_test_case(
|
||||
dataset_id=dataset_id,
|
||||
question=question,
|
||||
reference_answer=req.get("reference_answer"),
|
||||
relevant_doc_ids=req.get("relevant_doc_ids"),
|
||||
relevant_chunk_ids=req.get("relevant_chunk_ids"),
|
||||
metadata=req.get("metadata")
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"case_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("cases")
|
||||
async def import_test_cases(dataset_id):
|
||||
"""
|
||||
Bulk import test cases.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"cases": [
|
||||
{
|
||||
"question": "Question 1",
|
||||
"reference_answer": "Answer 1",
|
||||
...
|
||||
},
|
||||
{
|
||||
"question": "Question 2",
|
||||
...
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
cases = req.get("cases", [])
|
||||
|
||||
if not cases or not isinstance(cases, list):
|
||||
return get_data_error_result(message="cases must be a non-empty list")
|
||||
|
||||
success_count, failure_count = EvaluationService.import_test_cases(
|
||||
dataset_id=dataset_id,
|
||||
cases=cases
|
||||
)
|
||||
|
||||
return get_json_result(data={
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
"total": len(cases)
|
||||
})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_test_cases(dataset_id):
|
||||
"""Get all test cases for a dataset"""
|
||||
try:
|
||||
cases = EvaluationService.get_test_cases(dataset_id)
|
||||
return get_json_result(data={"cases": cases, "total": len(cases)})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_test_case(case_id):
|
||||
"""Delete a test case"""
|
||||
try:
|
||||
success = EvaluationService.delete_test_case(case_id)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to delete test case")
|
||||
|
||||
return get_json_result(data={"case_id": case_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Evaluation Execution ====================
|
||||
|
||||
@manager.route('/run/start', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("dataset_id", "dialog_id")
|
||||
async def start_evaluation():
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"dataset_id": "dataset_id",
|
||||
"dialog_id": "dialog_id",
|
||||
"name": "Optional run name"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
dataset_id = req.get("dataset_id")
|
||||
dialog_id = req.get("dialog_id")
|
||||
name = req.get("name")
|
||||
|
||||
success, result = EvaluationService.start_evaluation(
|
||||
dataset_id=dataset_id,
|
||||
dialog_id=dialog_id,
|
||||
user_id=current_user.id,
|
||||
name=name
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"run_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_evaluation_run(run_id):
|
||||
"""Get evaluation run details"""
|
||||
try:
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_run_results(run_id):
|
||||
"""Get detailed results for an evaluation run"""
|
||||
try:
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def list_evaluation_runs():
|
||||
"""
|
||||
List evaluation runs.
|
||||
|
||||
Query params:
|
||||
- dataset_id: Filter by dataset (optional)
|
||||
- dialog_id: Filter by dialog (optional)
|
||||
- page: Page number (default: 1)
|
||||
- page_size: Items per page (default: 20)
|
||||
"""
|
||||
try:
|
||||
# TODO: Implement list_runs in EvaluationService
|
||||
return get_json_result(data={"runs": [], "total": 0})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_evaluation_run(run_id):
|
||||
"""Delete an evaluation run"""
|
||||
try:
|
||||
# TODO: Implement delete_run in EvaluationService
|
||||
return get_json_result(data={"run_id": run_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Analysis & Recommendations ====================
|
||||
|
||||
@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_recommendations(run_id):
|
||||
"""Get configuration recommendations based on evaluation results"""
|
||||
try:
|
||||
recommendations = EvaluationService.get_recommendations(run_id)
|
||||
return get_json_result(data={"recommendations": recommendations})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/compare', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("run_ids")
|
||||
async def compare_runs():
|
||||
"""
|
||||
Compare multiple evaluation runs.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"run_ids": ["run_id1", "run_id2", "run_id3"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
run_ids = req.get("run_ids", [])
|
||||
|
||||
if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
|
||||
return get_data_error_result(
|
||||
message="run_ids must be a list with at least 2 run IDs"
|
||||
)
|
||||
|
||||
# TODO: Implement compare_runs in EvaluationService
|
||||
return get_json_result(data={"comparison": {}})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def export_results(run_id):
|
||||
"""Export evaluation results as JSON/CSV"""
|
||||
try:
|
||||
# format_type = request.args.get("format", "json") # TODO: Use for CSV export
|
||||
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
# TODO: Implement CSV export
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Real-time Evaluation ====================
|
||||
|
||||
@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "dialog_id")
|
||||
async def evaluate_single():
|
||||
"""
|
||||
Evaluate a single question-answer pair in real-time.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"question": "Test question",
|
||||
"dialog_id": "dialog_id",
|
||||
"reference_answer": "Optional ground truth",
|
||||
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# req = await get_request_json() # TODO: Use for single evaluation implementation
|
||||
|
||||
# TODO: Implement single evaluation
|
||||
# This would execute the RAG pipeline and return metrics immediately
|
||||
|
||||
return get_json_result(data={
|
||||
"answer": "",
|
||||
"metrics": {},
|
||||
"retrieved_chunks": []
|
||||
})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -19,22 +19,20 @@ from pathlib import Path
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.db import FileType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/convert', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids", "kb_ids")
|
||||
def convert():
|
||||
req = request.json
|
||||
async def convert():
|
||||
req = await get_request_json()
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
@ -70,7 +68,7 @@ def convert():
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
message="Can't find this dataset!")
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
@ -79,7 +77,8 @@ def convert():
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": file.type,
|
||||
@ -103,8 +102,8 @@ def convert():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
if not file_ids:
|
||||
return get_json_result(
|
||||
|
||||
@ -14,13 +14,12 @@
|
||||
# limitations under the License
|
||||
#
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request, make_response
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -31,7 +30,7 @@ from common.constants import RetCode, FileSource
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.api_utils import get_json_result, get_request_json
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
@ -40,17 +39,19 @@ from common import settings
|
||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
# @validate_request("parent_id")
|
||||
def upload():
|
||||
pf_id = request.form.get("parent_id")
|
||||
async def upload():
|
||||
form = await request.form
|
||||
pf_id = form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
files = await request.files
|
||||
if 'file' not in files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
file_objs = request.files.getlist('file')
|
||||
file_objs = files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
@ -61,9 +62,10 @@ def upload():
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_data_error_result( message="Can't find this folder!")
|
||||
for file_obj in file_objs:
|
||||
|
||||
async def _handle_single_file(file_obj):
|
||||
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
|
||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||
|
||||
# split file name path
|
||||
@ -75,35 +77,36 @@ def upload():
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# get folder
|
||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# create folder
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
# file type
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||
location += "_"
|
||||
blob = file_obj.read()
|
||||
filename = duplicate_name(
|
||||
blob = await asyncio.to_thread(file_obj.read)
|
||||
filename = await asyncio.to_thread(
|
||||
duplicate_name,
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
parent_id=last_folder.id)
|
||||
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
|
||||
file = {
|
||||
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||
file_data = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": last_folder.id,
|
||||
"tenant_id": current_user.id,
|
||||
@ -113,8 +116,13 @@ def upload():
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
}
|
||||
file = FileService.insert(file)
|
||||
file_res.append(file.to_json())
|
||||
inserted = await asyncio.to_thread(FileService.insert, file_data)
|
||||
return inserted.to_json()
|
||||
|
||||
for file_obj in file_objs:
|
||||
res = await _handle_single_file(file_obj)
|
||||
file_res.append(res)
|
||||
|
||||
return get_json_result(data=file_res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -123,10 +131,10 @@ def upload():
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
pf_id = req.get("parent_id")
|
||||
input_file_type = req.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@ -238,59 +246,62 @@ def get_all_parent_folders():
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
if file.location:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e and doc:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if tenant_id:
|
||||
DocumentService.remove_document(doc, tenant_id)
|
||||
File2DocumentService.delete_by_file_id(file.id)
|
||||
|
||||
FileService.delete(file)
|
||||
|
||||
def _delete_folder_recursive(folder, tenant_id):
|
||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||
for sub_file in sub_files:
|
||||
if sub_file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(sub_file, tenant_id)
|
||||
else:
|
||||
_delete_single_file(sub_file)
|
||||
|
||||
FileService.delete(folder)
|
||||
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e or not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
if file.location:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception as e:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e and doc:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if tenant_id:
|
||||
DocumentService.remove_document(doc, tenant_id)
|
||||
File2DocumentService.delete_by_file_id(file.id)
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
continue
|
||||
FileService.delete(file)
|
||||
|
||||
_delete_single_file(file)
|
||||
def _delete_folder_recursive(folder, tenant_id):
|
||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||
for sub_file in sub_files:
|
||||
if sub_file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(sub_file, tenant_id)
|
||||
else:
|
||||
_delete_single_file(sub_file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
FileService.delete(folder)
|
||||
|
||||
def _rm_sync():
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e or not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
continue
|
||||
|
||||
_delete_single_file(file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -299,8 +310,8 @@ def rm():
|
||||
@manager.route('/rename', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
async def rename():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
@ -338,7 +349,7 @@ def rename():
|
||||
|
||||
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(file_id):
|
||||
async def get(file_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@ -346,12 +357,12 @@ def get(file_id):
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = settings.STORAGE_IMPL.get(b, n)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
@ -368,8 +379,8 @@ def get(file_id):
|
||||
@manager.route("/mv", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("src_file_ids", "dest_file_id")
|
||||
def move():
|
||||
req = request.json
|
||||
async def move():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
dest_parent_id = req["dest_file_id"]
|
||||
@ -444,10 +455,12 @@ def move():
|
||||
},
|
||||
)
|
||||
|
||||
for file in files:
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
def _move_sync():
|
||||
for file in files:
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return get_json_result(data=True)
|
||||
return await asyncio.to_thread(_move_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -16,12 +16,12 @@
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
import numpy as np
|
||||
|
||||
|
||||
from api.db.services.connector_service import Connector2KbService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
@ -30,7 +30,8 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
|
||||
get_request_json
|
||||
from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
@ -41,23 +42,28 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
e, res = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = current_user.id,
|
||||
parser_id = req.pop("parser_id", None),
|
||||
**req
|
||||
)
|
||||
|
||||
if not e:
|
||||
return res
|
||||
|
||||
try:
|
||||
if not KnowledgebaseService.save(**req):
|
||||
if not KnowledgebaseService.save(**res):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id":req["id"]})
|
||||
return get_json_result(data={"kb_id":res["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -66,8 +72,8 @@ def create():
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "description", "parser_id")
|
||||
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.json
|
||||
async def update():
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -87,19 +93,19 @@ def update():
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
data=False, message='Only owner of dataset authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
message="Can't find this dataset!")
|
||||
|
||||
if req["name"].lower() != kb.name.lower() \
|
||||
and len(
|
||||
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
|
||||
return get_data_error_result(
|
||||
message="Duplicated knowledgebase name.")
|
||||
message="Duplicated dataset name.")
|
||||
|
||||
del req["kb_id"]
|
||||
connectors = []
|
||||
@ -111,12 +117,22 @@ def update():
|
||||
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.update,
|
||||
{"kb_id": kb.id},
|
||||
{PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb.id,
|
||||
)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.update,
|
||||
{"exists": PAGERANK_FLD},
|
||||
{"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb.id,
|
||||
)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e:
|
||||
@ -146,12 +162,12 @@ def detail():
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
data=False, message='Only owner of dataset authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
kb = KnowledgebaseService.get_detail(kb_id)
|
||||
if not kb:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
message="Can't find this dataset!")
|
||||
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
|
||||
kb["connectors"] = Connector2KbService.list_connectors(kb_id)
|
||||
|
||||
@ -165,18 +181,19 @@ def detail():
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_kbs():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
async def list_kbs():
|
||||
args = request.args
|
||||
keywords = args.get("keywords", "")
|
||||
page_number = int(args.get("page", 0))
|
||||
items_per_page = int(args.get("page_size", 0))
|
||||
parser_id = args.get("parser_id")
|
||||
orderby = args.get("orderby", "create_time")
|
||||
if args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -198,11 +215,12 @@ def list_kbs():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -214,28 +232,31 @@ def rm():
|
||||
created_by=current_user.id, id=req["kb_id"])
|
||||
if not kbs:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
data=False, message='Only owner of dataset authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
def _rm_sync():
|
||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
if f2d:
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
if f2d:
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -278,8 +299,8 @@ def list_tags_from_kbs():
|
||||
|
||||
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rm_tags(kb_id):
|
||||
req = request.json
|
||||
async def rm_tags(kb_id):
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -298,8 +319,8 @@ def rm_tags(kb_id):
|
||||
|
||||
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rename_tags(kb_id):
|
||||
req = request.json
|
||||
async def rename_tags(kb_id):
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -402,7 +423,7 @@ def get_basic_info():
|
||||
|
||||
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_logs():
|
||||
async def list_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -421,7 +442,7 @@ def list_pipeline_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -446,7 +467,7 @@ def list_pipeline_logs():
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
async def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -463,7 +484,7 @@ def list_pipeline_dataset_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -480,12 +501,12 @@ def list_pipeline_dataset_logs():
|
||||
|
||||
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_pipeline_logs():
|
||||
async def delete_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
log_ids = req.get("log_ids", [])
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
@ -509,8 +530,8 @@ def pipeline_log_detail():
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
async def run_graphrag():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -578,8 +599,8 @@ def trace_graphrag():
|
||||
|
||||
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_raptor():
|
||||
req = request.json
|
||||
async def run_raptor():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -647,8 +668,8 @@ def trace_raptor():
|
||||
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_mindmap():
|
||||
req = request.json
|
||||
async def run_mindmap():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -731,6 +752,8 @@ def delete_kb_task():
|
||||
def cancel_task(task_id):
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
|
||||
kb_task_id_field: str = ""
|
||||
kb_task_finish_at: str = ""
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
kb_task_id_field = "graphrag_task_id"
|
||||
@ -761,7 +784,7 @@ def delete_kb_task():
|
||||
|
||||
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
def check_embedding():
|
||||
async def check_embedding():
|
||||
|
||||
def _guess_vec_field(src: dict) -> str | None:
|
||||
for k in src or {}:
|
||||
@ -807,12 +830,12 @@ def check_embedding():
|
||||
offset=0, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
total = docStoreConn.getTotal(res0)
|
||||
total = docStoreConn.get_total(res0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
n = min(n, total)
|
||||
offsets = sorted(random.sample(range(total), n))
|
||||
offsets = sorted(random.sample(range(min(total,1000)), n))
|
||||
out = []
|
||||
|
||||
for off in offsets:
|
||||
@ -824,7 +847,7 @@ def check_embedding():
|
||||
offset=off, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
ids = docStoreConn.getChunkIds(res1)
|
||||
ids = docStoreConn.get_chunk_ids(res1)
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
@ -845,9 +868,14 @@ def check_embedding():
|
||||
"position_int": full_doc.get("position_int"),
|
||||
"top_int": full_doc.get("top_int"),
|
||||
"content_with_weight": full_doc.get("content_with_weight") or "",
|
||||
"question_kwd": full_doc.get("question_kwd") or []
|
||||
})
|
||||
return out
|
||||
req = request.json
|
||||
|
||||
def _clean(s: str) -> str:
|
||||
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
|
||||
return s if s else "None"
|
||||
req = await get_request_json()
|
||||
kb_id = req.get("kb_id", "")
|
||||
embd_id = req.get("embd_id", "")
|
||||
n = int(req.get("check_num", 5))
|
||||
@ -859,8 +887,10 @@ def check_embedding():
|
||||
|
||||
results, eff_sims = [], []
|
||||
for ck in samples:
|
||||
txt = (ck.get("content_with_weight") or "").strip()
|
||||
if not txt:
|
||||
title = ck.get("doc_name") or "Title"
|
||||
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
|
||||
txt_in = _clean(txt_in)
|
||||
if not txt_in:
|
||||
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
|
||||
continue
|
||||
|
||||
@ -869,10 +899,19 @@ def check_embedding():
|
||||
continue
|
||||
|
||||
try:
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
sim = _cos_sim(qv, ck["vector"])
|
||||
except Exception:
|
||||
return get_error_data_result(message="embedding failure")
|
||||
v, _ = emb_mdl.encode([title, txt_in])
|
||||
assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})"
|
||||
sim_content = _cos_sim(v[1], ck["vector"])
|
||||
title_w = 0.1
|
||||
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
|
||||
sim_mix = _cos_sim(qv_mix, ck["vector"])
|
||||
sim = sim_content
|
||||
mode = "content_only"
|
||||
if sim_mix > sim:
|
||||
sim = sim_mix
|
||||
mode = "title+content"
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Embedding failure. {e}")
|
||||
|
||||
eff_sims.append(sim)
|
||||
results.append({
|
||||
@ -892,9 +931,8 @@ def check_embedding():
|
||||
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"match_mode": mode,
|
||||
}
|
||||
if summary["avg_cos_sim"] > 0.99:
|
||||
if summary["avg_cos_sim"] > 0.9:
|
||||
return get_json_result(data={"summary": summary, "results": results})
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
||||
|
||||
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
|
||||
|
||||
@ -15,28 +15,28 @@
|
||||
#
|
||||
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from api.apps import current_user, login_required
|
||||
from langfuse import Langfuse
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("secret_key", "public_key", "host")
|
||||
def set_api_key():
|
||||
req = request.get_json()
|
||||
async def set_api_key():
|
||||
req = await get_request_json()
|
||||
secret_key = req.get("secret_key", "")
|
||||
public_key = req.get("public_key", "")
|
||||
host = req.get("host", "")
|
||||
if not all([secret_key, public_key, host]):
|
||||
return get_error_data_result(message="Missing required fields")
|
||||
|
||||
current_user_id = current_user.id
|
||||
langfuse_keys = dict(
|
||||
tenant_id=current_user.id,
|
||||
tenant_id=current_user_id,
|
||||
secret_key=secret_key,
|
||||
public_key=public_key,
|
||||
host=host,
|
||||
@ -46,23 +46,24 @@ def set_api_key():
|
||||
if not langfuse.auth_check():
|
||||
return get_error_data_result(message="Invalid Langfuse keys")
|
||||
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
|
||||
with DB.atomic():
|
||||
try:
|
||||
if not langfuse_entry:
|
||||
TenantLangfuseService.save(**langfuse_keys)
|
||||
else:
|
||||
TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
|
||||
TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
|
||||
return get_json_result(data=langfuse_keys)
|
||||
except Exception as e:
|
||||
server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/api_key", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request()
|
||||
def get_api_key():
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
|
||||
current_user_id = current_user.id
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
|
||||
if not langfuse_entry:
|
||||
return get_json_result(message="Have not record any Langfuse keys.")
|
||||
|
||||
@ -73,7 +74,7 @@ def get_api_key():
|
||||
except langfuse.api.core.api_error.ApiError as api_err:
|
||||
return get_json_result(message=f"Error from Langfuse: {api_err}")
|
||||
except Exception as e:
|
||||
server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
|
||||
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
|
||||
@ -85,7 +86,8 @@ def get_api_key():
|
||||
@login_required
|
||||
@validate_request()
|
||||
def delete_api_key():
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
||||
current_user_id = current_user.id
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
|
||||
if not langfuse_entry:
|
||||
return get_json_result(message="Have not record any Langfuse keys.")
|
||||
|
||||
@ -94,4 +96,4 @@ def delete_api_key():
|
||||
TenantLangfuseService.delete_model(langfuse_entry)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
@ -16,16 +16,16 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.constants import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
|
||||
from rag.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
|
||||
|
||||
|
||||
@manager.route("/factories", methods=["GET"]) # noqa: F821
|
||||
@ -43,7 +43,13 @@ def factories():
|
||||
mdl_types[m.fid] = set([])
|
||||
mdl_types[m.fid].add(m.model_type)
|
||||
for f in fac:
|
||||
f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
|
||||
f["model_types"] = list(
|
||||
mdl_types.get(
|
||||
f["name"],
|
||||
[LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS, LLMType.OCR],
|
||||
)
|
||||
)
|
||||
|
||||
return get_json_result(data=fac)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -52,8 +58,8 @@ def factories():
|
||||
@manager.route("/set_api_key", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
async def set_api_key():
|
||||
req = await get_request_json()
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
@ -74,7 +80,7 @@ def set_api_key():
|
||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
|
||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
|
||||
if m.find("**ERROR**") >= 0:
|
||||
raise Exception(m)
|
||||
chat_passed = True
|
||||
@ -122,8 +128,8 @@ def set_api_key():
|
||||
@manager.route("/add_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def add_llm():
|
||||
req = request.json
|
||||
async def add_llm():
|
||||
req = await get_request_json()
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
@ -142,16 +148,16 @@ def add_llm():
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
||||
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
||||
api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
|
||||
|
||||
elif factory == "LocalAI":
|
||||
llm_name += "___LocalAI"
|
||||
@ -186,6 +192,9 @@ def add_llm():
|
||||
elif factory == "OpenRouter":
|
||||
api_key = apikey_json(["api_key", "provider_order"])
|
||||
|
||||
elif factory == "MinerU":
|
||||
api_key = apikey_json(["api_key", "provider_order"])
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": factory,
|
||||
@ -217,7 +226,7 @@ def add_llm():
|
||||
**extra,
|
||||
)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||
if not tc and m.find("**ERROR**:") >= 0:
|
||||
raise Exception(m)
|
||||
except Exception as e:
|
||||
@ -251,6 +260,15 @@ def add_llm():
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.OCR.value:
|
||||
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
|
||||
try:
|
||||
mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
|
||||
ok, reason = mdl.check_available()
|
||||
if not ok:
|
||||
raise RuntimeError(reason or "Model not available")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
else:
|
||||
# TODO: check other type of models
|
||||
pass
|
||||
@ -267,8 +285,8 @@ def add_llm():
|
||||
@manager.route("/delete_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def delete_llm():
|
||||
req = request.json
|
||||
async def delete_llm():
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -276,8 +294,8 @@ def delete_llm():
|
||||
@manager.route("/enable_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def enable_llm():
|
||||
req = request.json
|
||||
async def enable_llm():
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
|
||||
)
|
||||
@ -287,8 +305,8 @@ def enable_llm():
|
||||
@manager.route("/delete_factory", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def delete_factory():
|
||||
req = request.json
|
||||
async def delete_factory():
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -297,6 +315,7 @@ def delete_factory():
|
||||
@login_required
|
||||
def my_llms():
|
||||
try:
|
||||
TenantLLMService.ensure_mineru_from_env(current_user.id)
|
||||
include_details = request.args.get("include_details", "false").lower() == "true"
|
||||
|
||||
if include_details:
|
||||
@ -344,6 +363,7 @@ def list_app():
|
||||
weighted = []
|
||||
model_type = request.args.get("model_type")
|
||||
try:
|
||||
TenantLLMService.ensure_mineru_from_env(current_user.id)
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
|
||||
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
|
||||
|
||||
@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
import asyncio
|
||||
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
from api.db.db_models import MCPServer
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
@ -22,15 +24,14 @@ from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_mcp_tools
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_mcp() -> Response:
|
||||
async def list_mcp() -> Response:
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
@ -40,7 +41,7 @@ def list_mcp() -> Response:
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
try:
|
||||
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
|
||||
@ -72,8 +73,8 @@ def detail() -> Response:
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "url", "server_type")
|
||||
def create() -> Response:
|
||||
req = request.get_json()
|
||||
async def create() -> Response:
|
||||
req = await get_request_json()
|
||||
|
||||
server_type = req.get("server_type", "")
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
@ -107,7 +108,7 @@ def create() -> Response:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
@ -127,8 +128,8 @@ def create() -> Response:
|
||||
@manager.route("/update", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id")
|
||||
def update() -> Response:
|
||||
req = request.get_json()
|
||||
async def update() -> Response:
|
||||
req = await get_request_json()
|
||||
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
@ -159,7 +160,7 @@ def update() -> Response:
|
||||
req["id"] = mcp_id
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
@ -183,8 +184,8 @@ def update() -> Response:
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def rm() -> Response:
|
||||
req = request.get_json()
|
||||
async def rm() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
try:
|
||||
@ -201,8 +202,8 @@ def rm() -> Response:
|
||||
@manager.route("/import", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcpServers")
|
||||
def import_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
async def import_multiple() -> Response:
|
||||
req = await get_request_json()
|
||||
servers = req.get("mcpServers", {})
|
||||
if not servers:
|
||||
return get_data_error_result(message="No MCP servers provided.")
|
||||
@ -243,7 +244,7 @@ def import_multiple() -> Response:
|
||||
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
|
||||
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
|
||||
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
|
||||
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
results.append({"server": base_name, "success": False, "message": err_message})
|
||||
continue
|
||||
@ -268,8 +269,8 @@ def import_multiple() -> Response:
|
||||
@manager.route("/export", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def export_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
async def export_multiple() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
if not mcp_ids:
|
||||
@ -300,8 +301,8 @@ def export_multiple() -> Response:
|
||||
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def list_tools() -> Response:
|
||||
req = request.get_json()
|
||||
async def list_tools() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
@ -323,7 +324,7 @@ def list_tools() -> Response:
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
|
||||
try:
|
||||
tools = tool_call_session.get_tools(timeout)
|
||||
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
|
||||
except Exception as e:
|
||||
tools = []
|
||||
return get_data_error_result(message=f"MCP list tools error: {e}")
|
||||
@ -341,14 +342,14 @@ def list_tools() -> Response:
|
||||
return server_error_response(e)
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
|
||||
|
||||
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tool_name", "arguments")
|
||||
def test_tool() -> Response:
|
||||
req = request.get_json()
|
||||
async def test_tool() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -368,10 +369,10 @@ def test_tool() -> Response:
|
||||
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
result = tool_call_session.tool_call(tool_name, arguments, timeout)
|
||||
result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
|
||||
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -380,8 +381,8 @@ def test_tool() -> Response:
|
||||
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tools")
|
||||
def cache_tool() -> Response:
|
||||
req = request.get_json()
|
||||
async def cache_tool() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -403,8 +404,8 @@ def cache_tool() -> Response:
|
||||
|
||||
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
|
||||
@validate_request("url", "server_type")
|
||||
def test_mcp() -> Response:
|
||||
req = request.get_json()
|
||||
async def test_mcp() -> Response:
|
||||
req = await get_request_json()
|
||||
|
||||
url = req.get("url", "")
|
||||
if not url:
|
||||
@ -425,13 +426,13 @@ def test_mcp() -> Response:
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
|
||||
try:
|
||||
tools = tool_call_session.get_tools(timeout)
|
||||
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
|
||||
except Exception as e:
|
||||
tools = []
|
||||
return get_data_error_result(message=f"Test MCP error: {e}")
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions([tool_call_session])
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
|
||||
|
||||
for tool in tools:
|
||||
tool_dict = tool.model_dump()
|
||||
|
||||
185
api/apps/memories_app.py
Normal file
185
api/apps/memories_app.py
Normal file
@ -0,0 +1,185 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
|
||||
not_allowed_parameters
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||
|
||||
|
||||
@manager.route("", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "memory_type", "embd_id", "llm_id")
|
||||
async def create_memory():
|
||||
req = await get_request_json()
|
||||
# check name length
|
||||
name = req["name"]
|
||||
memory_name = name.strip()
|
||||
if len(memory_name) == 0:
|
||||
return get_error_argument_result("Memory name cannot be empty or whitespace.")
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
|
||||
# check memory_type valid
|
||||
memory_type = set(req["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
|
||||
memory_type = list(memory_type)
|
||||
|
||||
try:
|
||||
res, memory = MemoryService.create_memory(
|
||||
tenant_id=current_user.id,
|
||||
name=memory_name,
|
||||
memory_type=memory_type,
|
||||
embd_id=req["embd_id"],
|
||||
llm_id=req["llm_id"]
|
||||
)
|
||||
|
||||
if res:
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
|
||||
else:
|
||||
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
|
||||
|
||||
except Exception as e:
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
|
||||
async def update_memory(memory_id):
|
||||
req = await get_request_json()
|
||||
update_dict = {}
|
||||
# check name length
|
||||
if "name" in req:
|
||||
name = req["name"]
|
||||
memory_name = name.strip()
|
||||
if len(memory_name) == 0:
|
||||
return get_error_argument_result("Memory name cannot be empty or whitespace.")
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
|
||||
update_dict["name"] = memory_name
|
||||
# check permissions valid
|
||||
if req.get("permissions"):
|
||||
if req["permissions"] not in [e.value for e in TenantPermission]:
|
||||
return get_error_argument_result(f"Unknown permission '{req['permissions']}'.")
|
||||
update_dict["permissions"] = req["permissions"]
|
||||
if req.get("llm_id"):
|
||||
update_dict["llm_id"] = req["llm_id"]
|
||||
# check memory_size valid
|
||||
if req.get("memory_size"):
|
||||
if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT:
|
||||
return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.")
|
||||
update_dict["memory_size"] = req["memory_size"]
|
||||
# check forgetting_policy valid
|
||||
if req.get("forgetting_policy"):
|
||||
if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]:
|
||||
return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.")
|
||||
update_dict["forgetting_policy"] = req["forgetting_policy"]
|
||||
# check temperature valid
|
||||
if "temperature" in req:
|
||||
temperature = float(req["temperature"])
|
||||
if not 0 <= temperature <= 1:
|
||||
return get_error_argument_result("Temperature should be in range [0, 1].")
|
||||
update_dict["temperature"] = temperature
|
||||
# allow update to empty fields
|
||||
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
|
||||
if field in req:
|
||||
update_dict[field] = req[field]
|
||||
current_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not current_memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
|
||||
memory_dict = current_memory.to_dict()
|
||||
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
|
||||
to_update = {}
|
||||
for k, v in update_dict.items():
|
||||
if isinstance(v, list) and set(memory_dict[k]) != set(v):
|
||||
to_update[k] = v
|
||||
elif memory_dict[k] != v:
|
||||
to_update[k] = v
|
||||
|
||||
if not to_update:
|
||||
return get_json_result(message=True, data=memory_dict)
|
||||
|
||||
try:
|
||||
MemoryService.update_memory(memory_id, to_update)
|
||||
updated_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(message=True, code=RetCode.NOT_FOUND)
|
||||
try:
|
||||
MemoryService.delete_memory(memory_id)
|
||||
return get_json_result(message=True)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_memory():
|
||||
args = request.args
|
||||
try:
|
||||
tenant_ids = args.getlist("tenant_id")
|
||||
memory_types = args.getlist("memory_type")
|
||||
storage_type = args.get("storage_type")
|
||||
keywords = args.get("keywords", "")
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
# make filter dict
|
||||
filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
|
||||
if not tenant_ids:
|
||||
# restrict to current user's tenants
|
||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||
else:
|
||||
filter_dict["tenant_id"] = tenant_ids
|
||||
|
||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
||||
return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count})
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_config(memory_id):
|
||||
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
@ -15,8 +15,8 @@
|
||||
#
|
||||
|
||||
|
||||
from flask import Response
|
||||
from flask_login import login_required
|
||||
from quart import Response
|
||||
from api.apps import login_required
|
||||
from api.utils.api_utils import get_json_result
|
||||
from plugin import GlobalPluginManager
|
||||
|
||||
|
||||
@ -14,20 +14,29 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import jwt
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from flask import request, Response
|
||||
from quart import request, Response
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@ -41,19 +50,19 @@ def list_agents(tenant_id):
|
||||
return get_error_data_result("The agent doesn't exist.")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
order_by = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
|
||||
return get_result(data=canvas)
|
||||
|
||||
|
||||
@manager.route("/agents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], request.json)
|
||||
async def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -89,8 +98,8 @@ def create_agent(tenant_id: str):
|
||||
|
||||
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None}
|
||||
async def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -132,48 +141,776 @@ def delete_agent(tenant_id: str, agent_id: str):
|
||||
UserCanvasService.delete_by_id(agent_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
|
||||
async def webhook(agent_id: str):
|
||||
is_test = request.path.startswith("/api/v1/webhook_test")
|
||||
start_ts = time.time()
|
||||
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def webhook(tenant_id: str, agent_id: str):
|
||||
req = request.json
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
# 1. Fetch canvas by agent_id
|
||||
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not exists:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
|
||||
|
||||
# 2. Check canvas category
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
|
||||
|
||||
# 3. Load DSL from canvas
|
||||
dsl = getattr(cvs, "dsl", None)
|
||||
if not isinstance(dsl, dict):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||
|
||||
# 4. Check webhook configuration in DSL
|
||||
components = dsl.get("components", {})
|
||||
for k, _ in components.items():
|
||||
cpn_obj = components[k]["obj"]
|
||||
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
|
||||
webhook_cfg = cpn_obj["params"]
|
||||
|
||||
if not webhook_cfg:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
|
||||
|
||||
# 5. Validate request method against webhook_cfg.methods
|
||||
allowed_methods = webhook_cfg.get("methods", [])
|
||||
request_method = request.method.upper()
|
||||
if allowed_methods and request_method not in allowed_methods:
|
||||
return get_data_error_result(
|
||||
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||
),RetCode.BAD_REQUEST
|
||||
|
||||
# 6. Validate webhook security
|
||||
async def validate_webhook_security(security_cfg: dict):
|
||||
"""Validate webhook security rules based on security configuration."""
|
||||
|
||||
if not security_cfg:
|
||||
return # No security config → allowed by default
|
||||
|
||||
# 1. Validate max body size
|
||||
await _validate_max_body_size(security_cfg)
|
||||
|
||||
# 2. Validate IP whitelist
|
||||
_validate_ip_whitelist(security_cfg)
|
||||
|
||||
# # 3. Validate rate limiting
|
||||
_validate_rate_limit(security_cfg)
|
||||
|
||||
# 4. Validate authentication
|
||||
auth_type = security_cfg.get("auth_type", "none")
|
||||
|
||||
if auth_type == "none":
|
||||
return
|
||||
|
||||
if auth_type == "token":
|
||||
_validate_token_auth(security_cfg)
|
||||
|
||||
elif auth_type == "basic":
|
||||
_validate_basic_auth(security_cfg)
|
||||
|
||||
elif auth_type == "jwt":
|
||||
_validate_jwt_auth(security_cfg)
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported auth_type: {auth_type}")
|
||||
|
||||
async def _validate_max_body_size(security_cfg):
|
||||
"""Check request size does not exceed max_body_size."""
|
||||
max_size = security_cfg.get("max_body_size")
|
||||
if not max_size:
|
||||
return
|
||||
|
||||
# Convert "10MB" → bytes
|
||||
units = {"kb": 1024, "mb": 1024**2}
|
||||
size_str = max_size.lower()
|
||||
|
||||
for suffix, factor in units.items():
|
||||
if size_str.endswith(suffix):
|
||||
limit = int(size_str.replace(suffix, "")) * factor
|
||||
break
|
||||
else:
|
||||
raise Exception("Invalid max_body_size format")
|
||||
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||
if limit > MAX_LIMIT:
|
||||
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
|
||||
|
||||
content_length = request.content_length or 0
|
||||
if content_length > limit:
|
||||
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||
|
||||
def _validate_ip_whitelist(security_cfg):
|
||||
"""Allow only IPs listed in ip_whitelist."""
|
||||
whitelist = security_cfg.get("ip_whitelist", [])
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
client_ip = request.remote_addr
|
||||
|
||||
|
||||
for rule in whitelist:
|
||||
if "/" in rule:
|
||||
# CIDR notation
|
||||
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||
return
|
||||
else:
|
||||
# Single IP
|
||||
if client_ip == rule:
|
||||
return
|
||||
|
||||
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||
|
||||
def _validate_rate_limit(security_cfg):
|
||||
"""Simple in-memory rate limiting."""
|
||||
rl = security_cfg.get("rate_limit")
|
||||
if not rl:
|
||||
return
|
||||
|
||||
limit = int(rl.get("limit", 60))
|
||||
if limit <= 0:
|
||||
raise Exception("rate_limit.limit must be > 0")
|
||||
per = rl.get("per", "minute")
|
||||
|
||||
window = {
|
||||
"second": 1,
|
||||
"minute": 60,
|
||||
"hour": 3600,
|
||||
"day": 86400,
|
||||
}.get(per)
|
||||
|
||||
if not window:
|
||||
raise Exception(f"Invalid rate_limit.per: {per}")
|
||||
|
||||
capacity = limit
|
||||
rate = limit / window
|
||||
cost = 1
|
||||
|
||||
key = f"rl:tb:{agent_id}"
|
||||
now = time.time()
|
||||
|
||||
try:
|
||||
res = REDIS_CONN.lua_token_bucket(
|
||||
keys=[key],
|
||||
args=[capacity, rate, now, cost],
|
||||
client=REDIS_CONN.REDIS,
|
||||
)
|
||||
|
||||
allowed = int(res[0])
|
||||
if allowed != 1:
|
||||
raise Exception("Too many requests (rate limit exceeded)")
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Rate limit error: {e}")
|
||||
|
||||
def _validate_token_auth(security_cfg):
|
||||
"""Validate header-based token authentication."""
|
||||
token_cfg = security_cfg.get("token",{})
|
||||
header = token_cfg.get("token_header")
|
||||
token_value = token_cfg.get("token_value")
|
||||
|
||||
provided = request.headers.get(header)
|
||||
if provided != token_value:
|
||||
raise Exception("Invalid token authentication")
|
||||
|
||||
def _validate_basic_auth(security_cfg):
|
||||
"""Validate HTTP Basic Auth credentials."""
|
||||
auth_cfg = security_cfg.get("basic_auth", {})
|
||||
username = auth_cfg.get("username")
|
||||
password = auth_cfg.get("password")
|
||||
|
||||
auth = request.authorization
|
||||
if not auth or auth.username != username or auth.password != password:
|
||||
raise Exception("Invalid Basic Auth credentials")
|
||||
|
||||
def _validate_jwt_auth(security_cfg):
|
||||
"""Validate JWT token in Authorization header."""
|
||||
jwt_cfg = security_cfg.get("jwt", {})
|
||||
secret = jwt_cfg.get("secret")
|
||||
if not secret:
|
||||
raise Exception("JWT secret not configured")
|
||||
required_claims = jwt_cfg.get("required_claims", [])
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise Exception("Missing Bearer token")
|
||||
|
||||
token = auth_header[len("Bearer "):].strip()
|
||||
if not token:
|
||||
raise Exception("Empty Bearer token")
|
||||
|
||||
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
|
||||
|
||||
decode_kwargs = {
|
||||
"key": secret,
|
||||
"algorithms": [alg],
|
||||
}
|
||||
options = {}
|
||||
if jwt_cfg.get("audience"):
|
||||
decode_kwargs["audience"] = jwt_cfg["audience"]
|
||||
options["verify_aud"] = True
|
||||
else:
|
||||
options["verify_aud"] = False
|
||||
|
||||
if jwt_cfg.get("issuer"):
|
||||
decode_kwargs["issuer"] = jwt_cfg["issuer"]
|
||||
options["verify_iss"] = True
|
||||
else:
|
||||
options["verify_iss"] = False
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
options=options,
|
||||
**decode_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Invalid JWT: {str(e)}")
|
||||
|
||||
raw_required_claims = jwt_cfg.get("required_claims", [])
|
||||
if isinstance(raw_required_claims, str):
|
||||
required_claims = [raw_required_claims]
|
||||
elif isinstance(raw_required_claims, (list, tuple, set)):
|
||||
required_claims = list(raw_required_claims)
|
||||
else:
|
||||
required_claims = []
|
||||
|
||||
required_claims = [
|
||||
c for c in required_claims
|
||||
if isinstance(c, str) and c.strip()
|
||||
]
|
||||
|
||||
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
|
||||
for claim in required_claims:
|
||||
if claim in RESERVED_CLAIMS:
|
||||
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
|
||||
|
||||
for claim in required_claims:
|
||||
if claim not in decoded:
|
||||
raise Exception(f"Missing JWT claim: {claim}")
|
||||
|
||||
return decoded
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||
security_config=webhook_cfg.get("security", {})
|
||||
await validate_webhook_security(security_config)
|
||||
except Exception as e:
|
||||
return get_json_result(
|
||||
data=False, message=str(e),
|
||||
code=RetCode.EXCEPTION_ERROR)
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
if not isinstance(cvs.dsl, str):
|
||||
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
try:
|
||||
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||
except Exception as e:
|
||||
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
|
||||
resp.status_code = RetCode.BAD_REQUEST
|
||||
return resp
|
||||
|
||||
# 7. Parse request body
|
||||
async def parse_webhook_request(content_type):
|
||||
"""Parse request based on content-type and return structured data."""
|
||||
|
||||
# 1. Query
|
||||
query_data = {k: v for k, v in request.args.items()}
|
||||
|
||||
# 2. Headers
|
||||
header_data = {k: v for k, v in request.headers.items()}
|
||||
|
||||
# 3. Body
|
||||
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
if ctype and ctype != content_type:
|
||||
raise ValueError(
|
||||
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
|
||||
)
|
||||
|
||||
body_data: dict = {}
|
||||
|
||||
def sse():
|
||||
nonlocal canvas
|
||||
try:
|
||||
for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
if ctype == "application/json":
|
||||
body_data = await request.get_json() or {}
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
elif ctype == "multipart/form-data":
|
||||
nonlocal canvas
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
body_data = {}
|
||||
|
||||
for key, value in form.items():
|
||||
body_data[key] = value
|
||||
|
||||
if len(files) > 10:
|
||||
raise Exception("Too many uploaded files")
|
||||
for key, file in files.items():
|
||||
desc = FileService.upload_info(
|
||||
cvs.user_id, # user
|
||||
file, # FileStorage
|
||||
None # url (None for webhook)
|
||||
)
|
||||
file_parsed= await canvas.get_files_async([desc])
|
||||
body_data[key] = file_parsed
|
||||
|
||||
elif ctype == "application/x-www-form-urlencoded":
|
||||
form = await request.form
|
||||
body_data = dict(form)
|
||||
|
||||
else:
|
||||
# text/plain / octet-stream / empty / unknown
|
||||
raw = await request.get_data()
|
||||
if raw:
|
||||
try:
|
||||
body_data = json.loads(raw.decode("utf-8"))
|
||||
except Exception:
|
||||
body_data = {}
|
||||
else:
|
||||
body_data = {}
|
||||
|
||||
except Exception:
|
||||
body_data = {}
|
||||
|
||||
return {
|
||||
"query": query_data,
|
||||
"headers": header_data,
|
||||
"body": body_data,
|
||||
"content_type": ctype,
|
||||
}
|
||||
|
||||
def extract_by_schema(data, schema, name="section"):
|
||||
"""
|
||||
Extract only fields defined in schema.
|
||||
Required fields must exist.
|
||||
Optional fields default to type-based default values.
|
||||
Type validation included.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
extracted = {}
|
||||
|
||||
for field, field_schema in props.items():
|
||||
field_type = field_schema.get("type")
|
||||
|
||||
# 1. Required field missing
|
||||
if field in required and field not in data:
|
||||
raise Exception(f"{name} missing required field: {field}")
|
||||
|
||||
# 2. Optional → default value
|
||||
if field not in data:
|
||||
extracted[field] = default_for_type(field_type)
|
||||
continue
|
||||
|
||||
raw_value = data[field]
|
||||
|
||||
# 3. Auto convert value
|
||||
try:
|
||||
value = auto_cast_value(raw_value, field_type)
|
||||
except Exception as e:
|
||||
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
|
||||
|
||||
# 4. Type validation
|
||||
if not validate_type(value, field_type):
|
||||
raise Exception(
|
||||
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||
)
|
||||
|
||||
extracted[field] = value
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
def default_for_type(t):
|
||||
"""Return default value for the given schema type."""
|
||||
if t == "file":
|
||||
return []
|
||||
if t == "object":
|
||||
return {}
|
||||
if t == "boolean":
|
||||
return False
|
||||
if t == "number":
|
||||
return 0
|
||||
if t == "string":
|
||||
return ""
|
||||
if t and t.startswith("array"):
|
||||
return []
|
||||
if t == "null":
|
||||
return None
|
||||
return None
|
||||
|
||||
def auto_cast_value(value, expected_type):
|
||||
"""Convert string values into schema type when possible."""
|
||||
|
||||
# Non-string values already good
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
v = value.strip()
|
||||
|
||||
# Boolean
|
||||
if expected_type == "boolean":
|
||||
if v.lower() in ["true", "1"]:
|
||||
return True
|
||||
if v.lower() in ["false", "0"]:
|
||||
return False
|
||||
raise Exception(f"Cannot convert '{value}' to boolean")
|
||||
|
||||
# Number
|
||||
if expected_type == "number":
|
||||
# integer
|
||||
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
|
||||
return int(v)
|
||||
|
||||
# float
|
||||
try:
|
||||
return float(v)
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to number")
|
||||
|
||||
# Object
|
||||
if expected_type == "object":
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an object")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to object")
|
||||
|
||||
# Array <T>
|
||||
if expected_type.startswith("array"):
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an array")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to array")
|
||||
|
||||
# String (accept original)
|
||||
if expected_type == "string":
|
||||
return value
|
||||
|
||||
# File
|
||||
if expected_type == "file":
|
||||
return value
|
||||
# Default: do nothing
|
||||
return value
|
||||
|
||||
|
||||
def validate_type(value, t):
|
||||
"""Validate value type against schema type t."""
|
||||
if t == "file":
|
||||
return isinstance(value, list)
|
||||
|
||||
if t == "string":
|
||||
return isinstance(value, str)
|
||||
|
||||
if t == "number":
|
||||
return isinstance(value, (int, float))
|
||||
|
||||
if t == "boolean":
|
||||
return isinstance(value, bool)
|
||||
|
||||
if t == "object":
|
||||
return isinstance(value, dict)
|
||||
|
||||
# array<string> / array<number> / array<object>
|
||||
if t.startswith("array"):
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if "<" in t and ">" in t:
|
||||
inner = t[t.find("<") + 1 : t.find(">")]
|
||||
|
||||
# Check each element type
|
||||
for item in value:
|
||||
if not validate_type(item, inner):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return True
|
||||
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
|
||||
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||
|
||||
# Extract strictly by schema
|
||||
try:
|
||||
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||
except Exception as e:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
|
||||
clean_request = {
|
||||
"query": query_clean,
|
||||
"headers": header_clean,
|
||||
"body": body_clean,
|
||||
"input": parsed
|
||||
}
|
||||
|
||||
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
|
||||
response_cfg = webhook_cfg.get("response", {})
|
||||
|
||||
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
|
||||
key = f"webhook-trace-{agent_id}-logs"
|
||||
|
||||
raw = REDIS_CONN.get(key)
|
||||
obj = json.loads(raw) if raw else {"webhooks": {}}
|
||||
|
||||
ws = obj["webhooks"].setdefault(
|
||||
str(start_ts),
|
||||
{"start_ts": start_ts, "events": []}
|
||||
)
|
||||
|
||||
ws["events"].append({
|
||||
"ts": time.time(),
|
||||
**event
|
||||
})
|
||||
|
||||
REDIS_CONN.set_obj(key, obj, ttl)
|
||||
|
||||
if execution_mode == "Immediately":
|
||||
status = response_cfg.get("status", 200)
|
||||
try:
|
||||
status = int(status)
|
||||
except (TypeError, ValueError):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
|
||||
|
||||
if not (200 <= status <= 399):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
|
||||
|
||||
body_tpl = response_cfg.get("body_template", "")
|
||||
|
||||
def parse_body(body: str):
|
||||
if not body:
|
||||
return None, "application/json"
|
||||
|
||||
try:
|
||||
parsed = json.loads(body)
|
||||
return parsed, "application/json"
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return body, "text/plain"
|
||||
|
||||
|
||||
body, content_type = parse_body(body_tpl)
|
||||
resp = Response(
|
||||
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
|
||||
status=status,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
async def background_run():
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request
|
||||
):
|
||||
if is_test:
|
||||
append_webhook_trace(agent_id, start_ts, ans)
|
||||
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Webhook background run failed")
|
||||
if is_test:
|
||||
try:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Failed to append webhook trace")
|
||||
|
||||
asyncio.create_task(background_run())
|
||||
return resp
|
||||
else:
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
contents: list[str] = []
|
||||
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request,
|
||||
):
|
||||
if ans["event"] == "message":
|
||||
content = ans["data"]["content"]
|
||||
if ans["data"].get("start_to_think", False):
|
||||
content = "<think>"
|
||||
elif ans["data"].get("end_to_think", False):
|
||||
content = "</think>"
|
||||
if content:
|
||||
contents.append(content)
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
ans
|
||||
)
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
final_content = "".join(contents)
|
||||
yield json.dumps(final_content, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
)
|
||||
yield json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False)
|
||||
|
||||
resp = Response(sse(), mimetype="application/json")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821
|
||||
async def webhook_trace(agent_id: str):
|
||||
def encode_webhook_id(start_ts: str) -> str:
|
||||
WEBHOOK_ID_SECRET = "webhook_id_secret"
|
||||
sig = hmac.new(
|
||||
WEBHOOK_ID_SECRET.encode("utf-8"),
|
||||
start_ts.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
|
||||
|
||||
def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
|
||||
for ts in webhooks.keys():
|
||||
if encode_webhook_id(ts) == enc_id:
|
||||
return ts
|
||||
return None
|
||||
since_ts = request.args.get("since_ts", type=float)
|
||||
webhook_id = request.args.get("webhook_id")
|
||||
|
||||
key = f"webhook-trace-{agent_id}-logs"
|
||||
raw = REDIS_CONN.get(key)
|
||||
|
||||
if since_ts is None:
|
||||
now = time.time()
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": now,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
if not raw:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
obj = json.loads(raw)
|
||||
webhooks = obj.get("webhooks", {})
|
||||
|
||||
if webhook_id is None:
|
||||
candidates = [
|
||||
float(k) for k in webhooks.keys() if float(k) > since_ts
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
start_ts = min(candidates)
|
||||
real_id = str(start_ts)
|
||||
webhook_id = encode_webhook_id(real_id)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": [],
|
||||
"next_since_ts": start_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
real_id = decode_webhook_id(webhook_id, webhooks)
|
||||
|
||||
if not real_id:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": True,
|
||||
}
|
||||
)
|
||||
|
||||
ws = webhooks.get(str(real_id))
|
||||
events = ws.get("events", [])
|
||||
new_events = [e for e in events if e.get("ts", 0) > since_ts]
|
||||
|
||||
next_ts = since_ts
|
||||
for e in new_events:
|
||||
next_ts = max(next_ts, e["ts"])
|
||||
|
||||
finished = any(e.get("event") == "finished" for e in new_events)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": new_events,
|
||||
"next_since_ts": next_ts,
|
||||
"finished": finished,
|
||||
}
|
||||
)
|
||||
|
||||
@ -14,22 +14,20 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
from quart import request
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json
|
||||
|
||||
|
||||
@manager.route("/chats", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
req = request.json
|
||||
async def create(tenant_id):
|
||||
req = await get_request_json()
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
@ -94,7 +92,7 @@ def create(tenant_id):
|
||||
req["tenant_id"] = tenant_id
|
||||
# prompt more parameter
|
||||
default_prompt = {
|
||||
"system": """You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.
|
||||
"system": """You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" Answers need to consider chat history.
|
||||
Here is the knowledge base:
|
||||
{knowledge}
|
||||
The above is the knowledge base.""",
|
||||
@ -145,10 +143,10 @@ def create(tenant_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id):
|
||||
async def update(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
ids = req.get("dataset_ids", [])
|
||||
if "show_quotation" in req:
|
||||
req["do_refer"] = req.pop("show_quotation")
|
||||
@ -176,7 +174,9 @@ def update(tenant_id, chat_id):
|
||||
req["llm_id"] = llm.pop("model_name")
|
||||
if req.get("llm_id") is not None:
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
|
||||
model_type = llm.pop("model_type")
|
||||
model_type = model_type if model_type in ["chat", "image2text"] else "chat"
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type):
|
||||
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
|
||||
req["llm_setting"] = req.pop("llm")
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
@ -228,10 +228,10 @@ def update(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/chats", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
async def delete_chats(tenant_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
ids = None
|
||||
else:
|
||||
@ -251,8 +251,7 @@ def delete(tenant_id):
|
||||
errors.append(f"Assistant({id}) not found.")
|
||||
continue
|
||||
temp_dict = {"status": StatusEnum.INVALID.value}
|
||||
DialogService.update_by_id(id, temp_dict)
|
||||
success_count += 1
|
||||
success_count += DialogService.update_by_id(id, temp_dict)
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
|
||||
@ -18,13 +18,14 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from flask import request
|
||||
from quart import request
|
||||
from peewee import OperationalError
|
||||
from api.db.db_models import File
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, FileSource, StatusEnum
|
||||
from api.utils.api_utils import (
|
||||
@ -53,7 +54,7 @@ from common import settings
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
"""
|
||||
Create a new dataset.
|
||||
---
|
||||
@ -115,17 +116,19 @@ def create(tenant_id):
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
|
||||
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
req, err = await validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
e, req = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = tenant_id,
|
||||
parser_id = req.pop("parser_id", None),
|
||||
**req
|
||||
)
|
||||
|
||||
if not e:
|
||||
return req
|
||||
|
||||
# Insert embedding model(embd id)
|
||||
ok, t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
@ -144,7 +147,6 @@ def create(tenant_id):
|
||||
ok, k = KnowledgebaseService.get_by_id(req["id"])
|
||||
if not ok:
|
||||
return get_error_data_result(message="Dataset created failed")
|
||||
|
||||
response_data = remap_dictionary_keys(k.to_dict())
|
||||
return get_result(data=response_data)
|
||||
except Exception as e:
|
||||
@ -153,7 +155,7 @@ def create(tenant_id):
|
||||
|
||||
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
async def delete(tenant_id):
|
||||
"""
|
||||
Delete datasets.
|
||||
---
|
||||
@ -191,7 +193,7 @@ def delete(tenant_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
|
||||
req, err = await validate_and_parse_json_request(request, DeleteDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
@ -251,7 +253,7 @@ def delete(tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, dataset_id):
|
||||
async def update(tenant_id, dataset_id):
|
||||
"""
|
||||
Update a dataset.
|
||||
---
|
||||
@ -317,7 +319,7 @@ def update(tenant_id, dataset_id):
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
extras = {"dataset_id": dataset_id}
|
||||
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
@ -532,3 +534,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
|
||||
search.index_name(kb.tenant_id), dataset_id)
|
||||
|
||||
return get_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def run_graphrag(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=dataset_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
|
||||
|
||||
return get_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def trace_graphrag(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_result(data={})
|
||||
|
||||
return get_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def run_raptor(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=dataset_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||
logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
|
||||
|
||||
return get_result(data={"raptor_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def trace_raptor(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||
|
||||
return get_result(data=task.to_dict())
|
||||
@ -15,21 +15,21 @@
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request, jsonify
|
||||
from quart import jsonify
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from common.constants import RetCode, LLMType
|
||||
from common import settings
|
||||
|
||||
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
|
||||
@apikey_required
|
||||
@validate_request("knowledge_id", "query")
|
||||
def retrieval(tenant_id):
|
||||
async def retrieval(tenant_id):
|
||||
"""
|
||||
Dify-compatible retrieval API
|
||||
---
|
||||
@ -113,14 +113,14 @@ def retrieval(tenant_id):
|
||||
404:
|
||||
description: Knowledge base or document not found
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
question = req["query"]
|
||||
kb_id = req["knowledge_id"]
|
||||
use_kg = req.get("use_kg", False)
|
||||
retrieval_setting = req.get("retrieval_setting", {})
|
||||
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
|
||||
top = int(retrieval_setting.get("top_k", 1024))
|
||||
metadata_condition = req.get("metadata_condition", {})
|
||||
metadata_condition = req.get("metadata_condition", {}) or {}
|
||||
metas = DocumentService.get_meta_by_kbs([kb_id])
|
||||
|
||||
doc_ids = []
|
||||
@ -131,12 +131,10 @@ def retrieval(tenant_id):
|
||||
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
print(metadata_condition)
|
||||
# print("after", convert_conditions(metadata_condition))
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
|
||||
# print("doc_ids", doc_ids)
|
||||
if not doc_ids and metadata_condition is not None:
|
||||
doc_ids = ['-999']
|
||||
if metadata_condition:
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
if not doc_ids and metadata_condition:
|
||||
doc_ids = ["-999"]
|
||||
ranks = settings.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
|
||||
@ -14,13 +14,14 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import pathlib
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import xxhash
|
||||
from flask import request, send_file
|
||||
from quart import request, send_file
|
||||
from peewee import OperationalError
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
@ -33,9 +34,10 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
||||
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||
get_request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -69,7 +71,7 @@ class Chunk(BaseModel):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def upload(dataset_id, tenant_id):
|
||||
async def upload(dataset_id, tenant_id):
|
||||
"""
|
||||
Upload documents to a dataset.
|
||||
---
|
||||
@ -93,6 +95,10 @@ def upload(dataset_id, tenant_id):
|
||||
type: file
|
||||
required: true
|
||||
description: Document files to upload.
|
||||
- in: formData
|
||||
name: parent_path
|
||||
type: string
|
||||
description: Optional nested path under the parent folder. Uses '/' separators.
|
||||
responses:
|
||||
200:
|
||||
description: Successfully uploaded documents.
|
||||
@ -126,9 +132,11 @@ def upload(dataset_id, tenant_id):
|
||||
type: string
|
||||
description: Processing status.
|
||||
"""
|
||||
if "file" not in request.files:
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_result(message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -151,7 +159,7 @@ def upload(dataset_id, tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not e:
|
||||
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=form.get("parent_path"))
|
||||
if err:
|
||||
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
# rename key's name
|
||||
@ -175,7 +183,7 @@ def upload(dataset_id, tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update_doc(tenant_id, dataset_id, document_id):
|
||||
async def update_doc(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Update a document within a dataset.
|
||||
---
|
||||
@ -224,12 +232,12 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
|
||||
return get_error_data_result(message="You don't own the dataset.")
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not e:
|
||||
return get_error_data_result(message="Can't find this knowledgebase!")
|
||||
return get_error_data_result(message="Can't find this dataset!")
|
||||
doc = DocumentService.query(kb_id=dataset_id, id=document_id)
|
||||
if not doc:
|
||||
return get_error_data_result(message="The dataset doesn't own the document.")
|
||||
@ -314,9 +322,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
try:
|
||||
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
|
||||
return get_error_data_result(message="Database error (Document update)!")
|
||||
|
||||
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||
return get_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -343,19 +349,17 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
}
|
||||
renamed_doc = {}
|
||||
for key, value in doc.to_dict().items():
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_doc[new_key] = value
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(value)
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
|
||||
return get_result(data=renamed_doc)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def download(tenant_id, dataset_id, document_id):
|
||||
async def download(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Download a document from a dataset.
|
||||
---
|
||||
@ -405,10 +409,10 @@ def download(tenant_id, dataset_id, document_id):
|
||||
return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
|
||||
file = BytesIO(file_stream)
|
||||
# Use send_file with a proper filename and MIME type
|
||||
return send_file(
|
||||
return await send_file(
|
||||
file,
|
||||
as_attachment=True,
|
||||
download_name=doc[0].name,
|
||||
attachment_filename=doc[0].name,
|
||||
mimetype="application/octet-stream", # Set a default MIME type
|
||||
)
|
||||
|
||||
@ -529,7 +533,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
|
||||
q = request.args
|
||||
document_id = q.get("id")
|
||||
document_id = q.get("id")
|
||||
name = q.get("name")
|
||||
|
||||
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
|
||||
@ -538,23 +542,39 @@ def list_docs(dataset_id, tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the document {name}.")
|
||||
|
||||
page = int(q.get("page", 1))
|
||||
page_size = int(q.get("page_size", 30))
|
||||
page_size = int(q.get("page_size", 30))
|
||||
orderby = q.get("orderby", "create_time")
|
||||
desc = str(q.get("desc", "true")).strip().lower() != "false"
|
||||
keywords = q.get("keywords", "")
|
||||
|
||||
# filters - align with OpenAPI parameter names
|
||||
suffix = q.getlist("suffix")
|
||||
run_status = q.getlist("run")
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
create_time_to = int(q.get("create_time_to", 0))
|
||||
suffix = q.getlist("suffix")
|
||||
run_status = q.getlist("run")
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
create_time_to = int(q.get("create_time_to", 0))
|
||||
metadata_condition_raw = q.get("metadata_condition")
|
||||
metadata_condition = {}
|
||||
if metadata_condition_raw:
|
||||
try:
|
||||
metadata_condition = json.loads(metadata_condition_raw)
|
||||
except Exception:
|
||||
return get_error_data_result(message="metadata_condition must be valid JSON.")
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
# map run status (accept text or numeric) - align with API parameter
|
||||
# map run status (text or numeric) - align with API parameter
|
||||
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
|
||||
run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
|
||||
|
||||
doc_ids_filter = None
|
||||
if metadata_condition:
|
||||
metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
|
||||
doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
|
||||
if metadata_condition.get("conditions") and not doc_ids_filter:
|
||||
return get_result(data={"total": 0, "docs": []})
|
||||
|
||||
docs, total = DocumentService.get_list(
|
||||
dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted
|
||||
dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted, doc_ids_filter
|
||||
)
|
||||
|
||||
# time range filter (0 means no bound)
|
||||
@ -568,7 +588,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
# rename keys + map run status back to text for output
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"kb_id": "dataset_id",
|
||||
"kb_id": "dataset_id",
|
||||
"token_num": "token_count",
|
||||
"parser_id": "chunk_method",
|
||||
}
|
||||
@ -583,9 +603,73 @@ def list_docs(dataset_id, tenant_id):
|
||||
|
||||
return get_result(data={"total": total, "docs": output_docs})
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/metadata/summary", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def metadata_summary(dataset_id, tenant_id):
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
|
||||
try:
|
||||
summary = DocumentService.get_metadata_summary(dataset_id)
|
||||
return get_result(data={"summary": summary})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/metadata/update", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def metadata_batch_update(dataset_id, tenant_id):
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
|
||||
req = await get_request_json()
|
||||
selector = req.get("selector", {}) or {}
|
||||
updates = req.get("updates", []) or []
|
||||
deletes = req.get("deletes", []) or []
|
||||
|
||||
if not isinstance(selector, dict):
|
||||
return get_error_data_result(message="selector must be an object.")
|
||||
if not isinstance(updates, list) or not isinstance(deletes, list):
|
||||
return get_error_data_result(message="updates and deletes must be lists.")
|
||||
|
||||
metadata_condition = selector.get("metadata_condition", {}) or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
document_ids = selector.get("document_ids", []) or []
|
||||
if document_ids and not isinstance(document_ids, list):
|
||||
return get_error_data_result(message="document_ids must be a list.")
|
||||
|
||||
for upd in updates:
|
||||
if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
|
||||
return get_error_data_result(message="Each update requires key and value.")
|
||||
for d in deletes:
|
||||
if not isinstance(d, dict) or not d.get("key"):
|
||||
return get_error_data_result(message="Each delete requires key.")
|
||||
|
||||
kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
|
||||
target_doc_ids = set(kb_doc_ids)
|
||||
if document_ids:
|
||||
invalid_ids = set(document_ids) - set(kb_doc_ids)
|
||||
if invalid_ids:
|
||||
return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
|
||||
target_doc_ids = set(document_ids)
|
||||
|
||||
if metadata_condition:
|
||||
metas = DocumentService.get_flatted_meta_by_kbs([dataset_id])
|
||||
filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
target_doc_ids = target_doc_ids & filtered_ids
|
||||
if metadata_condition.get("conditions") and not target_doc_ids:
|
||||
return get_result(data={"updated": 0, "matched_docs": 0})
|
||||
|
||||
target_doc_ids = list(target_doc_ids)
|
||||
updated = DocumentService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
|
||||
return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, dataset_id):
|
||||
async def delete(tenant_id, dataset_id):
|
||||
"""
|
||||
Delete documents from a dataset.
|
||||
---
|
||||
@ -624,7 +708,7 @@ def delete(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
doc_ids = None
|
||||
else:
|
||||
@ -695,7 +779,7 @@ def delete(tenant_id, dataset_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def parse(tenant_id, dataset_id):
|
||||
async def parse(tenant_id, dataset_id):
|
||||
"""
|
||||
Start parsing documents into chunks.
|
||||
---
|
||||
@ -734,7 +818,7 @@ def parse(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
doc_list = req.get("document_ids")
|
||||
@ -778,7 +862,7 @@ def parse(tenant_id, dataset_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def stop_parsing(tenant_id, dataset_id):
|
||||
async def stop_parsing(tenant_id, dataset_id):
|
||||
"""
|
||||
Stop parsing documents into chunks.
|
||||
---
|
||||
@ -817,7 +901,7 @@ def stop_parsing(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
@ -832,6 +916,8 @@ def stop_parsing(tenant_id, dataset_id):
|
||||
return get_error_data_result(message=f"You don't own the document {id}.")
|
||||
if int(doc[0].progress) == 1 or doc[0].progress == 0:
|
||||
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
|
||||
# Send cancellation signal via Redis to stop background task
|
||||
cancel_all_task_of(id)
|
||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||
DocumentService.update_by_id(id, info)
|
||||
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
|
||||
@ -885,7 +971,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
description: Chunk Id.
|
||||
description: Chunk id.
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
@ -1019,7 +1105,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
"/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["POST"]
|
||||
)
|
||||
@token_required
|
||||
def add_chunk(tenant_id, dataset_id, document_id):
|
||||
async def add_chunk(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Add a chunk to a document.
|
||||
---
|
||||
@ -1089,7 +1175,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not str(req.get("content", "")).strip():
|
||||
return get_error_data_result(message="`content` is required")
|
||||
if "important_keywords" in req:
|
||||
@ -1148,7 +1234,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
"datasets/<dataset_id>/documents/<document_id>/chunks", methods=["DELETE"]
|
||||
)
|
||||
@token_required
|
||||
def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
async def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Remove chunks from a document.
|
||||
---
|
||||
@ -1195,7 +1281,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
docs = DocumentService.get_by_ids([document_id])
|
||||
if not docs:
|
||||
raise LookupError(f"Can't find the document with ID {document_id}!")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
condition = {"doc_id": document_id}
|
||||
if "chunk_ids" in req:
|
||||
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
|
||||
@ -1219,7 +1305,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
"/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id>", methods=["PUT"]
|
||||
)
|
||||
@token_required
|
||||
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
"""
|
||||
Update a chunk within a document.
|
||||
---
|
||||
@ -1281,8 +1367,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = request.json
|
||||
if "content" in req:
|
||||
req = await get_request_json()
|
||||
if "content" in req and req["content"] is not None:
|
||||
content = req["content"]
|
||||
else:
|
||||
content = chunk.get("content_with_weight", "")
|
||||
@ -1323,7 +1409,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
|
||||
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def retrieval_test(tenant_id):
|
||||
async def retrieval_test(tenant_id):
|
||||
"""
|
||||
Retrieve chunks based on a query.
|
||||
---
|
||||
@ -1404,7 +1490,7 @@ def retrieval_test(tenant_id):
|
||||
format: float
|
||||
description: Similarity score.
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req.get("dataset_ids"):
|
||||
return get_error_data_result("`dataset_ids` is required.")
|
||||
kb_ids = req["dataset_ids"]
|
||||
@ -1427,6 +1513,7 @@ def retrieval_test(tenant_id):
|
||||
question = req["question"]
|
||||
doc_ids = req.get("document_ids", [])
|
||||
use_kg = req.get("use_kg", False)
|
||||
toc_enhance = req.get("toc_enhance", False)
|
||||
langs = req.get("cross_languages", [])
|
||||
if not isinstance(doc_ids, list):
|
||||
return get_error_data_result("`documents` should be a list")
|
||||
@ -1435,9 +1522,14 @@ def retrieval_test(tenant_id):
|
||||
if doc_id not in doc_ids_list:
|
||||
return get_error_data_result(f"The datasets don't own the document {doc_id}")
|
||||
if not doc_ids:
|
||||
metadata_condition = req.get("metadata_condition", {})
|
||||
metadata_condition = req.get("metadata_condition", {}) or {}
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
doc_ids = meta_filter(metas, convert_conditions(metadata_condition))
|
||||
doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
|
||||
# If metadata_condition has conditions but no docs match, return empty result
|
||||
if not doc_ids and metadata_condition.get("conditions"):
|
||||
return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
|
||||
if metadata_condition and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
@ -1457,11 +1549,11 @@ def retrieval_test(tenant_id):
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
|
||||
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
question = await cross_languages(kb.tenant_id, None, question, langs)
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
question += await keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
question,
|
||||
@ -1478,6 +1570,11 @@ def retrieval_test(tenant_id):
|
||||
highlight=highlight,
|
||||
rank_feature=label_question(question, kbs),
|
||||
)
|
||||
if toc_enhance:
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
|
||||
if cks:
|
||||
ranks["chunks"] = cks
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
|
||||
@ -14,35 +14,34 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from quart import request, make_response
|
||||
from pathlib import Path
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, token_required
|
||||
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def upload(tenant_id):
|
||||
async def upload(tenant_id):
|
||||
"""
|
||||
Upload a file to the system.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -79,26 +78,28 @@ def upload(tenant_id):
|
||||
type: string
|
||||
description: File type (e.g., document, folder)
|
||||
"""
|
||||
pf_id = request.form.get("parent_id")
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
pf_id = form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(data=False, message='No file part!', code=400)
|
||||
file_objs = request.files.getlist('file')
|
||||
if 'file' not in files:
|
||||
return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST)
|
||||
file_objs = files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(data=False, message='No selected file!', code=400)
|
||||
return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST)
|
||||
|
||||
file_res = []
|
||||
|
||||
try:
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND)
|
||||
|
||||
for file_obj in file_objs:
|
||||
# Handle file path
|
||||
@ -114,13 +115,13 @@ def upload(tenant_id):
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
@ -151,12 +152,12 @@ def upload(tenant_id):
|
||||
|
||||
@manager.route('/file/create', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
"""
|
||||
Create a new file or folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -193,16 +194,16 @@ def create(tenant_id):
|
||||
type:
|
||||
type: string
|
||||
"""
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
req = await get_request_json()
|
||||
pf_id = req.get("parent_id")
|
||||
input_file_type = req.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
try:
|
||||
if not FileService.is_parent_folder_exist(pf_id):
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400)
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
|
||||
if FileService.query(name=req["name"], parent_id=pf_id):
|
||||
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
|
||||
|
||||
@ -229,12 +230,12 @@ def create(tenant_id):
|
||||
|
||||
@manager.route('/file/list', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_files(tenant_id):
|
||||
async def list_files(tenant_id):
|
||||
"""
|
||||
List files under a specific folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -306,13 +307,13 @@ def list_files(tenant_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(pf_id)
|
||||
if not parent_folder:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
|
||||
except Exception as e:
|
||||
@ -321,12 +322,12 @@ def list_files(tenant_id):
|
||||
|
||||
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_root_folder(tenant_id):
|
||||
async def get_root_folder(tenant_id):
|
||||
"""
|
||||
Get user's root folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
@ -357,12 +358,12 @@ def get_root_folder(tenant_id):
|
||||
|
||||
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_parent_folder():
|
||||
async def get_parent_folder():
|
||||
"""
|
||||
Get parent folder info of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -392,7 +393,7 @@ def get_parent_folder():
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(file_id)
|
||||
return get_json_result(data={"parent_folder": parent_folder.to_json()})
|
||||
@ -402,12 +403,12 @@ def get_parent_folder():
|
||||
|
||||
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_all_parent_folders(tenant_id):
|
||||
async def get_all_parent_folders(tenant_id):
|
||||
"""
|
||||
Get all parent folders of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -439,7 +440,7 @@ def get_all_parent_folders(tenant_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
parent_folders = FileService.get_all_parent_folders(file_id)
|
||||
parent_folders_res = [folder.to_json() for folder in parent_folders]
|
||||
@ -450,12 +451,12 @@ def get_all_parent_folders(tenant_id):
|
||||
|
||||
@manager.route('/file/rm', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rm(tenant_id):
|
||||
async def rm(tenant_id):
|
||||
"""
|
||||
Delete one or multiple files/folders.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -481,40 +482,40 @@ def rm(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for inner_file_id in file_id_list:
|
||||
e, file = FileService.get_by_id(inner_file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
FileService.delete_folder_by_pf_id(tenant_id, file_id)
|
||||
else:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
if not FileService.delete(file):
|
||||
return get_json_result(message="Database error (File removal)!", code=500)
|
||||
return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file_id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(message="Database error (Document removal)!", code=500)
|
||||
return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR)
|
||||
File2DocumentService.delete_by_file_id(file_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
@ -524,12 +525,12 @@ def rm(tenant_id):
|
||||
|
||||
@manager.route('/file/rename', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rename(tenant_id):
|
||||
async def rename(tenant_id):
|
||||
"""
|
||||
Rename a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -556,27 +557,27 @@ def rename(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=400)
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
|
||||
|
||||
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if existing_file.name == req["name"]:
|
||||
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
|
||||
|
||||
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (File rename)!", code=500)
|
||||
return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(req["file_id"])
|
||||
if informs:
|
||||
if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (Document rename)!", code=500)
|
||||
return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -585,12 +586,12 @@ def rename(tenant_id):
|
||||
|
||||
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get(tenant_id, file_id):
|
||||
async def get(tenant_id, file_id):
|
||||
"""
|
||||
Download a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
produces:
|
||||
@ -606,20 +607,20 @@ def get(tenant_id, file_id):
|
||||
description: File stream
|
||||
schema:
|
||||
type: file
|
||||
404:
|
||||
RetCode.NOT_FOUND:
|
||||
description: File not found
|
||||
"""
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = settings.STORAGE_IMPL.get(b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name)
|
||||
if ext:
|
||||
if file.type == FileType.VISUAL.value:
|
||||
@ -630,15 +631,28 @@ def get(tenant_id, file_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
async def download_attachment(tenant_id,attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def move(tenant_id):
|
||||
async def move(tenant_id):
|
||||
"""
|
||||
Move one or multiple files to another folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -667,7 +681,7 @@ def move(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
parent_id = req["dest_file_id"]
|
||||
@ -677,13 +691,13 @@ def move(tenant_id):
|
||||
for file_id in file_ids:
|
||||
file = files_dict[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_json_result(message="Parent Folder not found!", code=404)
|
||||
return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
FileService.move_file(file_ids, parent_id)
|
||||
return get_json_result(data=True)
|
||||
@ -693,8 +707,8 @@ def move(tenant_id):
|
||||
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def convert(tenant_id):
|
||||
req = request.json
|
||||
async def convert(tenant_id):
|
||||
req = await get_request_json()
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
@ -705,7 +719,7 @@ def convert(tenant_id):
|
||||
for file_id in file_ids:
|
||||
file = files_set[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
file_ids_list = [file_id]
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
@ -716,13 +730,13 @@ def convert(tenant_id):
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(
|
||||
message="Database error (Document removal)!", code=404)
|
||||
message="Database error (Document removal)!", code=RetCode.NOT_FOUND)
|
||||
File2DocumentService.delete_by_file_id(id)
|
||||
|
||||
# insert
|
||||
@ -730,11 +744,11 @@ def convert(tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this knowledgebase!", code=404)
|
||||
message="Can't find this dataset!", code=RetCode.NOT_FOUND)
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this file!", code=404)
|
||||
message="Can't find this file!", code=RetCode.NOT_FOUND)
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
|
||||
@ -14,38 +14,42 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import copy
|
||||
import re
|
||||
import time
|
||||
|
||||
import tiktoken
|
||||
from flask import Response, jsonify, request
|
||||
from quart import Response, jsonify, request
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService, completion_openai
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||
from api.db.services.conversation_service import completion as rag_completion
|
||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
||||
from api.db.services.conversation_service import async_completion as rag_completion
|
||||
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||
get_result, server_error_response, token_required, validate_request
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format
|
||||
from common.constants import RetCode, LLMType, StatusEnum
|
||||
from common import settings
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id, chat_id):
|
||||
req = request.json
|
||||
async def create(tenant_id, chat_id):
|
||||
req = await get_request_json()
|
||||
req["dialog_id"] = chat_id
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
@ -73,7 +77,7 @@ def create(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent_session(tenant_id, agent_id):
|
||||
async def create_agent_session(tenant_id, agent_id):
|
||||
user_id = request.args.get("user_id", tenant_id)
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not e:
|
||||
@ -97,8 +101,8 @@ def create_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id, session_id):
|
||||
req = request.json
|
||||
async def update(tenant_id, chat_id, session_id):
|
||||
req = await get_request_json()
|
||||
req["dialog_id"] = chat_id
|
||||
conv_id = session_id
|
||||
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
||||
@ -119,17 +123,39 @@ def update(tenant_id, chat_id, session_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def chat_completion(tenant_id, chat_id):
|
||||
req = request.json
|
||||
async def chat_completion(tenant_id, chat_id):
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
req["question"] = ""
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
dia = dia[0]
|
||||
if req.get("session_id"):
|
||||
if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
|
||||
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
||||
|
||||
metadata_condition = req.get("metadata_condition") or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
if metadata_condition and req.get("question"):
|
||||
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
|
||||
filtered_doc_ids = meta_filter(
|
||||
metas,
|
||||
convert_conditions(metadata_condition),
|
||||
metadata_condition.get("logic", "and"),
|
||||
)
|
||||
if metadata_condition.get("conditions") and not filtered_doc_ids:
|
||||
filtered_doc_ids = ["-999"]
|
||||
|
||||
if filtered_doc_ids:
|
||||
req["doc_ids"] = ",".join(filtered_doc_ids)
|
||||
else:
|
||||
req.pop("doc_ids", None)
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
@ -140,7 +166,7 @@ def chat_completion(tenant_id, chat_id):
|
||||
return resp
|
||||
else:
|
||||
answer = None
|
||||
for ans in rag_completion(tenant_id, chat_id, **req):
|
||||
async for ans in rag_completion(tenant_id, chat_id, **req):
|
||||
answer = ans
|
||||
break
|
||||
return get_result(data=answer)
|
||||
@ -149,7 +175,7 @@ def chat_completion(tenant_id, chat_id):
|
||||
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def chat_completion_openai_like(tenant_id, chat_id):
|
||||
async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"""
|
||||
OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
|
||||
|
||||
@ -192,7 +218,19 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
{"role": "user", "content": "Can you tell me how to install neovim"},
|
||||
],
|
||||
stream=stream,
|
||||
extra_body={"reference": reference}
|
||||
extra_body={
|
||||
"reference": reference,
|
||||
"metadata_condition": {
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "author",
|
||||
"comparison_operator": "is",
|
||||
"value": "bob"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -206,9 +244,13 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
if reference:
|
||||
print(completion.choices[0].message.reference)
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
need_reference = bool(req.get("reference", False))
|
||||
extra_body = req.get("extra_body") or {}
|
||||
if extra_body and not isinstance(extra_body, dict):
|
||||
return get_error_data_result("extra_body must be an object.")
|
||||
|
||||
need_reference = bool(extra_body.get("reference", False))
|
||||
|
||||
messages = req.get("messages", [])
|
||||
# To prevent empty [] input
|
||||
@ -226,6 +268,22 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
dia = dia[0]
|
||||
|
||||
metadata_condition = extra_body.get("metadata_condition") or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
doc_ids_str = None
|
||||
if metadata_condition:
|
||||
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
|
||||
filtered_doc_ids = meta_filter(
|
||||
metas,
|
||||
convert_conditions(metadata_condition),
|
||||
metadata_condition.get("logic", "and"),
|
||||
)
|
||||
if metadata_condition.get("conditions") and not filtered_doc_ids:
|
||||
filtered_doc_ids = ["-999"]
|
||||
doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
|
||||
|
||||
# Filter system and non-sense assistant messages
|
||||
msg = []
|
||||
for m in messages:
|
||||
@ -244,7 +302,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
# The value for the usage field on all chunks except for the last one will be null.
|
||||
# The usage field on the last chunk contains token usage statistics for the entire request.
|
||||
# The choices field on the last chunk will always be an empty array [].
|
||||
def streamed_response_generator(chat_id, dia, msg):
|
||||
async def streamed_response_generator(chat_id, dia, msg):
|
||||
token_used = 0
|
||||
answer_cache = ""
|
||||
reasoning_cache = ""
|
||||
@ -273,14 +331,17 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
}
|
||||
|
||||
try:
|
||||
for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
|
||||
if doc_ids_str:
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, True, **chat_kwargs):
|
||||
last_ans = ans
|
||||
answer = ans["answer"]
|
||||
|
||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||
if reasoning_match:
|
||||
reasoning_part = reasoning_match.group(1)
|
||||
content_part = answer[reasoning_match.end():]
|
||||
content_part = answer[reasoning_match.end() :]
|
||||
else:
|
||||
reasoning_part = ""
|
||||
content_part = answer
|
||||
@ -325,8 +386,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
response["choices"][0]["delta"]["content"] = None
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
|
||||
"total_tokens": len(prompt) + token_used}
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
if need_reference:
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||
@ -341,7 +401,10 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return resp
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
|
||||
if doc_ids_str:
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, False, **chat_kwargs):
|
||||
# focus answer content only
|
||||
answer = ans
|
||||
break
|
||||
@ -383,8 +446,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = request.json
|
||||
async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
messages = req.get("messages", [])
|
||||
if not messages:
|
||||
@ -428,35 +491,49 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
return resp
|
||||
else:
|
||||
# For non-streaming, just return the response directly
|
||||
response = next(
|
||||
completion_openai(
|
||||
async for response in completion_openai(
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
|
||||
stream=False,
|
||||
**req,
|
||||
)
|
||||
)
|
||||
return jsonify(response)
|
||||
):
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def agent_completions(tenant_id, agent_id):
|
||||
req = request.json
|
||||
async def agent_completions(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
return_trace = bool(req.get("return_trace", False))
|
||||
|
||||
if req.get("stream", True):
|
||||
|
||||
def generate():
|
||||
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
async def generate():
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
ans = json.loads(answer[5:]) # remove "data:"
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
event = ans.get("event")
|
||||
if event == "node_finished":
|
||||
if return_trace:
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
ans.setdefault("data", {})["trace"] = trace_items
|
||||
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
yield answer
|
||||
|
||||
if event not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
yield answer
|
||||
@ -473,7 +550,8 @@ def agent_completions(tenant_id, agent_id):
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = ""
|
||||
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
try:
|
||||
ans = json.loads(answer[5:])
|
||||
|
||||
@ -483,17 +561,28 @@ def agent_completions(tenant_id, agent_id):
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
|
||||
if return_trace and ans.get("event") == "node_finished":
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
|
||||
final_ans = ans
|
||||
except Exception as e:
|
||||
return get_result(data=f"**ERROR**: {str(e)}")
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if return_trace and final_ans:
|
||||
final_ans["data"]["trace"] = trace_items
|
||||
return get_result(data=final_ans)
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_session(tenant_id, chat_id):
|
||||
async def list_session(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
||||
id = request.args.get("id")
|
||||
@ -547,7 +636,7 @@ def list_session(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_agent_session(tenant_id, agent_id):
|
||||
async def list_agent_session(tenant_id, agent_id):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
||||
id = request.args.get("id")
|
||||
@ -610,13 +699,13 @@ def list_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, chat_id):
|
||||
async def delete(tenant_id, chat_id):
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You don't own the chat")
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
convs = ConversationService.query(dialog_id=chat_id)
|
||||
if not req:
|
||||
ids = None
|
||||
@ -661,10 +750,10 @@ def delete(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent_session(tenant_id, agent_id):
|
||||
async def delete_agent_session(tenant_id, agent_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
@ -716,8 +805,8 @@ def delete_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def ask_about(tenant_id):
|
||||
req = request.json
|
||||
async def ask_about(tenant_id):
|
||||
req = await get_request_json()
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
if not req.get("dataset_ids"):
|
||||
@ -734,10 +823,10 @@ def ask_about(tenant_id):
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
uid = tenant_id
|
||||
|
||||
def stream():
|
||||
async def stream():
|
||||
nonlocal req, uid
|
||||
try:
|
||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||
async for ans in async_ask(req["question"], req["kb_ids"], uid):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps(
|
||||
@ -755,8 +844,8 @@ def ask_about(tenant_id):
|
||||
|
||||
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def related_questions(tenant_id):
|
||||
req = request.json
|
||||
async def related_questions(tenant_id):
|
||||
req = await get_request_json()
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
question = req["question"]
|
||||
@ -789,7 +878,7 @@ Reason:
|
||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||
|
||||
"""
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
@ -806,8 +895,8 @@ Related search terms:
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def chatbot_completions(dialog_id):
|
||||
req = request.json
|
||||
async def chatbot_completions(dialog_id):
|
||||
req = await get_request_json()
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -828,12 +917,12 @@ def chatbot_completions(dialog_id):
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
for answer in iframe_completion(dialog_id, **req):
|
||||
async for answer in iframe_completion(dialog_id, **req):
|
||||
return get_result(data=answer)
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
|
||||
def chatbots_inputs(dialog_id):
|
||||
async def chatbots_inputs(dialog_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -856,8 +945,8 @@ def chatbots_inputs(dialog_id):
|
||||
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def agent_bot_completions(agent_id):
|
||||
req = request.json
|
||||
async def agent_bot_completions(agent_id):
|
||||
req = await get_request_json()
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -875,12 +964,12 @@ def agent_bot_completions(agent_id):
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
|
||||
async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
|
||||
return get_result(data=answer)
|
||||
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
||||
def begin_inputs(agent_id):
|
||||
async def begin_inputs(agent_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -901,7 +990,7 @@ def begin_inputs(agent_id):
|
||||
|
||||
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about_embedded():
|
||||
async def ask_about_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -910,7 +999,7 @@ def ask_about_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
uid = objs[0].tenant_id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -919,10 +1008,10 @@ def ask_about_embedded():
|
||||
if search_app := SearchService.get_detail(search_id):
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
def stream():
|
||||
async def stream():
|
||||
nonlocal req, uid
|
||||
try:
|
||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps(
|
||||
@ -940,7 +1029,7 @@ def ask_about_embedded():
|
||||
|
||||
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test_embedded():
|
||||
async def retrieval_test_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -949,7 +1038,7 @@ def retrieval_test_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -965,28 +1054,31 @@ def retrieval_test_embedded():
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
tenant_ids = []
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
async def _retrieval():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
_question = question
|
||||
|
||||
meta_data_filter = {}
|
||||
chat_mdl = None
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
else:
|
||||
meta_data_filter = req.get("meta_data_filter") or {}
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids)
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
@ -994,7 +1086,7 @@ def retrieval_test_embedded():
|
||||
tenant_ids.append(tenant.tenant_id)
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.",
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
@ -1002,7 +1094,7 @@ def retrieval_test_embedded():
|
||||
return get_error_data_result(message="Knowledgebase not found!")
|
||||
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
_question = await cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -1012,15 +1104,15 @@ def retrieval_test_embedded():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(
|
||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
|
||||
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
@ -1030,6 +1122,9 @@ def retrieval_test_embedded():
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await _retrieval()
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
||||
@ -1039,7 +1134,7 @@ def retrieval_test_embedded():
|
||||
|
||||
@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question")
|
||||
def related_questions_embedded():
|
||||
async def related_questions_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1048,7 +1143,7 @@ def related_questions_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
@ -1066,7 +1161,7 @@ def related_questions_embedded():
|
||||
|
||||
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
@ -1083,7 +1178,7 @@ Related search terms:
|
||||
|
||||
|
||||
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
|
||||
def detail_share_embedded():
|
||||
async def detail_share_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1115,7 +1210,7 @@ def detail_share_embedded():
|
||||
|
||||
@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
async def mindmap():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1125,12 +1220,12 @@ def mindmap():
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
|
||||
mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
|
||||
mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.db.db_models import DB
|
||||
@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/create", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.get_json()
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
search_name = req["name"]
|
||||
description = req.get("description", "")
|
||||
if not isinstance(search_name, str):
|
||||
@ -65,8 +65,8 @@ def create():
|
||||
@login_required
|
||||
@validate_request("search_id", "name", "search_config", "tenant_id")
|
||||
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.get_json()
|
||||
async def update():
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Search name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -140,7 +140,7 @@ def detail():
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_search_app():
|
||||
async def list_search_app():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
@ -150,7 +150,7 @@ def list_search_app():
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -173,8 +173,8 @@ def list_search_app():
|
||||
@manager.route("/rm", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("search_id")
|
||||
def rm():
|
||||
req = request.get_json()
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
search_id = req["search_id"]
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -17,7 +17,7 @@ import logging
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService
|
||||
@ -34,7 +34,7 @@ from common.time_utils import current_timestamp, datetime_format
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from quart import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -13,11 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.apps import smtp_mail_server
|
||||
import logging
|
||||
import asyncio
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import UserTenant
|
||||
from api.db.services.user_service import UserTenantService, UserService
|
||||
@ -25,9 +22,10 @@ from api.db.services.user_service import UserTenantService, UserService
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import delta_seconds
|
||||
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import send_invite_email
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
|
||||
@ -51,14 +49,14 @@ def user_list(tenant_id):
|
||||
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("email")
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
if current_user.id != tenant_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
invite_user_email = req["email"]
|
||||
invite_users = UserService.query(email=invite_user_email)
|
||||
if not invite_users:
|
||||
@ -83,20 +81,24 @@ def create(tenant_id):
|
||||
role=UserTenantRole.INVITE,
|
||||
status=StatusEnum.VALID.value)
|
||||
|
||||
if smtp_mail_server and settings.SMTP_CONF:
|
||||
from threading import Thread
|
||||
try:
|
||||
|
||||
user_name = ""
|
||||
_, user = UserService.get_by_id(current_user.id)
|
||||
if user:
|
||||
user_name = user.nickname
|
||||
|
||||
Thread(
|
||||
target=send_invite_email,
|
||||
args=(invite_user_email, settings.MAIL_FRONTEND_URL, tenant_id, user_name or current_user.email),
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
asyncio.create_task(
|
||||
send_invite_email(
|
||||
to_email=invite_user_email,
|
||||
invite_url=settings.MAIL_FRONTEND_URL,
|
||||
tenant_id=tenant_id,
|
||||
inviter=user_name or current_user.email
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(f"Failed to send invite email to {invite_user_email}: {e}")
|
||||
return get_json_result(data=False, message="Failed to send invite email.", code=RetCode.SERVER_ERROR)
|
||||
usr = invite_users[0].to_dict()
|
||||
usr = {k: v for k, v in usr.items() if k in ["id", "avatar", "email", "nickname"]}
|
||||
|
||||
|
||||
@ -21,9 +21,9 @@ import re
|
||||
import secrets
|
||||
import time
|
||||
from datetime import datetime
|
||||
import base64
|
||||
|
||||
from flask import redirect, request, session, make_response
|
||||
from flask_login import current_user, login_required, login_user, logout_user
|
||||
from quart import make_response, redirect, request, session
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
from api.apps.auth import get_auth_client
|
||||
@ -40,12 +40,13 @@ from common.connection_utils import construct_response
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
get_request_json,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import smtp_mail_server
|
||||
from api.apps import login_required, current_user, login_user, logout_user
|
||||
from api.utils.web_utils import (
|
||||
send_email_html,
|
||||
OTP_LENGTH,
|
||||
@ -58,10 +59,11 @@ from api.utils.web_utils import (
|
||||
captcha_key,
|
||||
)
|
||||
from common import settings
|
||||
from common.http_client import async_request
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
def login():
|
||||
async def login():
|
||||
"""
|
||||
User login endpoint.
|
||||
---
|
||||
@ -91,10 +93,14 @@ def login():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
if not request.json:
|
||||
json_body = await get_request_json()
|
||||
if not json_body:
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
email = request.json.get("email", "")
|
||||
email = json_body.get("email", "")
|
||||
if email == "admin@ragflow.io":
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Default admin account cannot be used to login normal services!")
|
||||
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(
|
||||
@ -103,7 +109,7 @@ def login():
|
||||
message=f"Email: {email} is not registered!",
|
||||
)
|
||||
|
||||
password = request.json.get("password")
|
||||
password = json_body.get("password")
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
@ -121,11 +127,12 @@ def login():
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.update_time = current_timestamp()
|
||||
user.update_date = datetime_format(datetime.now())
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
|
||||
return await construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -135,7 +142,7 @@ def login():
|
||||
|
||||
|
||||
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
||||
def get_login_channels():
|
||||
async def get_login_channels():
|
||||
"""
|
||||
Get all supported authentication channels.
|
||||
"""
|
||||
@ -156,7 +163,7 @@ def get_login_channels():
|
||||
|
||||
|
||||
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_login(channel):
|
||||
async def oauth_login(channel):
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
raise ValueError(f"Invalid channel name: {channel}")
|
||||
@ -169,7 +176,7 @@ def oauth_login(channel):
|
||||
|
||||
|
||||
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_callback(channel):
|
||||
async def oauth_callback(channel):
|
||||
"""
|
||||
Handle the OAuth/OIDC callback for various channels dynamically.
|
||||
"""
|
||||
@ -191,7 +198,10 @@ def oauth_callback(channel):
|
||||
return redirect("/?error=missing_code")
|
||||
|
||||
# Exchange authorization code for access token
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
if hasattr(auth_cli, "async_exchange_code_for_token"):
|
||||
token_info = await auth_cli.async_exchange_code_for_token(code)
|
||||
else:
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
access_token = token_info.get("access_token")
|
||||
if not access_token:
|
||||
return redirect("/?error=token_failed")
|
||||
@ -199,7 +209,10 @@ def oauth_callback(channel):
|
||||
id_token = token_info.get("id_token")
|
||||
|
||||
# Fetch user info
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if hasattr(auth_cli, "async_fetch_user_info"):
|
||||
user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
|
||||
else:
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if not user_info.email:
|
||||
return redirect("/?error=email_missing")
|
||||
|
||||
@ -258,7 +271,7 @@ def oauth_callback(channel):
|
||||
|
||||
|
||||
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
|
||||
def github_callback():
|
||||
async def github_callback():
|
||||
"""
|
||||
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
||||
|
||||
@ -278,9 +291,8 @@ def github_callback():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
res = requests.post(
|
||||
res = await async_request(
|
||||
"POST",
|
||||
settings.GITHUB_OAUTH.get("url"),
|
||||
data={
|
||||
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||
@ -298,7 +310,7 @@ def github_callback():
|
||||
|
||||
session["access_token"] = res["access_token"]
|
||||
session["access_token_from"] = "github"
|
||||
user_info = user_info_from_github(session["access_token"])
|
||||
user_info = await user_info_from_github(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
@ -347,7 +359,7 @@ def github_callback():
|
||||
|
||||
|
||||
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
|
||||
def feishu_callback():
|
||||
async def feishu_callback():
|
||||
"""
|
||||
Feishu OAuth callback endpoint.
|
||||
---
|
||||
@ -365,9 +377,8 @@ def feishu_callback():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
app_access_token_res = requests.post(
|
||||
app_access_token_res = await async_request(
|
||||
"POST",
|
||||
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
@ -381,7 +392,8 @@ def feishu_callback():
|
||||
if app_access_token_res["code"] != 0:
|
||||
return redirect("/?error=%s" % app_access_token_res)
|
||||
|
||||
res = requests.post(
|
||||
res = await async_request(
|
||||
"POST",
|
||||
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
@ -402,7 +414,7 @@ def feishu_callback():
|
||||
return redirect("/?error=contact:user.email:readonly not in scope")
|
||||
session["access_token"] = res["data"]["access_token"]
|
||||
session["access_token_from"] = "feishu"
|
||||
user_info = user_info_from_feishu(session["access_token"])
|
||||
user_info = await user_info_from_feishu(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
@ -450,36 +462,34 @@ def feishu_callback():
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
|
||||
|
||||
def user_info_from_feishu(access_token):
|
||||
import requests
|
||||
|
||||
async def user_info_from_feishu(access_token):
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
user_info = res.json()["data"]
|
||||
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
import requests
|
||||
|
||||
async def user_info_from_github(access_token):
|
||||
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(
|
||||
email_info_response = await async_request(
|
||||
"GET",
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers,
|
||||
).json()
|
||||
)
|
||||
email_info = email_info_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@manager.route("/logout", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def log_out():
|
||||
async def log_out():
|
||||
"""
|
||||
User logout endpoint.
|
||||
---
|
||||
@ -501,7 +511,7 @@ def log_out():
|
||||
|
||||
@manager.route("/setting", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def setting_user():
|
||||
async def setting_user():
|
||||
"""
|
||||
Update user settings.
|
||||
---
|
||||
@ -530,7 +540,7 @@ def setting_user():
|
||||
type: object
|
||||
"""
|
||||
update_dict = {}
|
||||
request_data = request.json
|
||||
request_data = await get_request_json()
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
@ -569,7 +579,7 @@ def setting_user():
|
||||
|
||||
@manager.route("/info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def user_profile():
|
||||
async def user_profile():
|
||||
"""
|
||||
Get user profile information.
|
||||
---
|
||||
@ -660,7 +670,7 @@ def user_register(user_id, user):
|
||||
|
||||
@manager.route("/register", methods=["POST"]) # noqa: F821
|
||||
@validate_request("nickname", "email", "password")
|
||||
def user_add():
|
||||
async def user_add():
|
||||
"""
|
||||
Register a new user.
|
||||
---
|
||||
@ -697,7 +707,7 @@ def user_add():
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
email_address = req["email"]
|
||||
|
||||
# Validate the email address
|
||||
@ -737,7 +747,7 @@ def user_add():
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return construct_response(
|
||||
return await construct_response(
|
||||
data=user.to_json(),
|
||||
auth=user.get_id(),
|
||||
message=f"{nickname}, welcome aboard!",
|
||||
@ -754,7 +764,7 @@ def user_add():
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def tenant_info():
|
||||
async def tenant_info():
|
||||
"""
|
||||
Get tenant information.
|
||||
---
|
||||
@ -793,7 +803,7 @@ def tenant_info():
|
||||
@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
|
||||
def set_tenant_info():
|
||||
async def set_tenant_info():
|
||||
"""
|
||||
Update tenant information.
|
||||
---
|
||||
@ -830,17 +840,17 @@ def set_tenant_info():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
tid = req.pop("tenant_id")
|
||||
TenantService.update_by_id(tid, req)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
|
||||
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
|
||||
def forget_get_captcha():
|
||||
async def forget_get_captcha():
|
||||
"""
|
||||
GET /forget/captcha?email=<email>
|
||||
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS.
|
||||
@ -862,19 +872,19 @@ def forget_get_captcha():
|
||||
from captcha.image import ImageCaptcha
|
||||
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
|
||||
img_bytes = image.generate(captcha_text).read()
|
||||
response = make_response(img_bytes)
|
||||
response = await make_response(img_bytes)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/forget/otp", methods=["POST"]) # noqa: F821
|
||||
def forget_send_otp():
|
||||
async def forget_send_otp():
|
||||
"""
|
||||
POST /forget/otp
|
||||
- Verify the image captcha stored at captcha:{email} (case-insensitive).
|
||||
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
email = req.get("email") or ""
|
||||
captcha = (req.get("captcha") or "").strip()
|
||||
|
||||
@ -917,47 +927,45 @@ def forget_send_otp():
|
||||
|
||||
ttl_min = OTP_TTL_SECONDS // 60
|
||||
|
||||
if not smtp_mail_server:
|
||||
logging.warning("SMTP mail server not initialized; skip sending email.")
|
||||
else:
|
||||
try:
|
||||
send_email_html(
|
||||
subject="Your Password Reset Code",
|
||||
to_email=email,
|
||||
template_key="reset_code",
|
||||
code=otp,
|
||||
ttl_min=ttl_min,
|
||||
)
|
||||
except Exception:
|
||||
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")
|
||||
|
||||
try:
|
||||
await send_email_html(
|
||||
subject="Your Password Reset Code",
|
||||
to_email=email,
|
||||
template_key="reset_code",
|
||||
code=otp,
|
||||
ttl_min=ttl_min,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")
|
||||
|
||||
return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")
|
||||
|
||||
|
||||
@manager.route("/forget", methods=["POST"]) # noqa: F821
|
||||
def forget():
|
||||
def _verified_key(email: str) -> str:
|
||||
return f"otp:verified:{email}"
|
||||
|
||||
|
||||
@manager.route("/forget/verify-otp", methods=["POST"]) # noqa: F821
|
||||
async def forget_verify_otp():
|
||||
"""
|
||||
POST: Verify email + OTP and reset password, then log the user in.
|
||||
Request JSON: { email, otp, new_password, confirm_new_password }
|
||||
Verify email + OTP only. On success:
|
||||
- consume the OTP and attempt counters
|
||||
- set a short-lived verified flag in Redis for the email
|
||||
Request JSON: { email, otp }
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
email = req.get("email") or ""
|
||||
otp = (req.get("otp") or "").strip()
|
||||
new_pwd = req.get("new_password")
|
||||
new_pwd2 = req.get("confirm_new_password")
|
||||
|
||||
if not all([email, otp, new_pwd, new_pwd2]):
|
||||
return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")
|
||||
|
||||
# For reset, passwords are provided as-is (no decrypt needed)
|
||||
if new_pwd != new_pwd2:
|
||||
return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match")
|
||||
if not all([email, otp]):
|
||||
return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and otp are required")
|
||||
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")
|
||||
|
||||
user = users[0]
|
||||
# Verify OTP from Redis
|
||||
k_code, k_attempts, k_last, k_lock = otp_keys(email)
|
||||
if REDIS_CONN.get(k_lock):
|
||||
@ -973,7 +981,6 @@ def forget():
|
||||
except Exception:
|
||||
return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="otp storage corrupted")
|
||||
|
||||
# Case-insensitive verification: OTP generated uppercase
|
||||
calc = hash_code(otp.upper(), salt)
|
||||
if calc != stored_hash:
|
||||
# bump attempts
|
||||
@ -986,23 +993,70 @@ def forget():
|
||||
REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="expired otp")
|
||||
|
||||
# Success: consume OTP and reset password
|
||||
# Success: consume OTP and attempts; mark verified
|
||||
REDIS_CONN.delete(k_code)
|
||||
REDIS_CONN.delete(k_attempts)
|
||||
REDIS_CONN.delete(k_last)
|
||||
REDIS_CONN.delete(k_lock)
|
||||
|
||||
# set verified flag with limited TTL, reuse OTP_TTL_SECONDS or smaller window
|
||||
try:
|
||||
UserService.update_user_password(user.id, new_pwd)
|
||||
REDIS_CONN.set(_verified_key(email), "1", OTP_TTL_SECONDS)
|
||||
except Exception:
|
||||
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to set verification state")
|
||||
|
||||
return get_json_result(data=True, code=RetCode.SUCCESS, message="otp verified")
|
||||
|
||||
|
||||
@manager.route("/forget/reset-password", methods=["POST"]) # noqa: F821
|
||||
async def forget_reset_password():
|
||||
"""
|
||||
Reset password after successful OTP verification.
|
||||
Requires: { email, new_password, confirm_new_password }
|
||||
Steps:
|
||||
- check verified flag in Redis
|
||||
- update user password
|
||||
- auto login
|
||||
- clear verified flag
|
||||
"""
|
||||
|
||||
req = await get_request_json()
|
||||
email = req.get("email") or ""
|
||||
new_pwd = req.get("new_password")
|
||||
new_pwd2 = req.get("confirm_new_password")
|
||||
|
||||
new_pwd_base64 = decrypt(new_pwd)
|
||||
new_pwd_string = base64.b64decode(new_pwd_base64).decode('utf-8')
|
||||
new_pwd2_string = base64.b64decode(decrypt(new_pwd2)).decode('utf-8')
|
||||
|
||||
REDIS_CONN.get(_verified_key(email))
|
||||
if not REDIS_CONN.get(_verified_key(email)):
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="email not verified")
|
||||
|
||||
if not all([email, new_pwd, new_pwd2]):
|
||||
return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and passwords are required")
|
||||
|
||||
if new_pwd_string != new_pwd2_string:
|
||||
return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match")
|
||||
|
||||
users = UserService.query_user_by_email(email=email)
|
||||
if not users:
|
||||
return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email")
|
||||
|
||||
user = users[0]
|
||||
try:
|
||||
UserService.update_user_password(user.id, new_pwd_base64)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="failed to reset password")
|
||||
|
||||
# Auto login (reuse login flow)
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.save()
|
||||
# clear verified flag
|
||||
try:
|
||||
REDIS_CONN.delete(_verified_key(email))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
msg = "Password reset successful. Logged in."
|
||||
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
|
||||
|
||||
|
||||
@ -24,3 +24,5 @@ REQUEST_MAX_WAIT_SEC = 300
|
||||
|
||||
DATASET_NAME_LIMIT = 128
|
||||
FILE_NAME_LEN_LIMIT = 255
|
||||
MEMORY_NAME_LIMIT = 128
|
||||
MEMORY_SIZE_LIMIT = 10*1024*1024 # Byte
|
||||
|
||||
@ -25,7 +25,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import UserMixin
|
||||
from quart_auth import AuthUser
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
@ -305,6 +305,7 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
|
||||
@ -594,7 +595,7 @@ def fill_db_model_object(model_object, human_model_dict):
|
||||
return model_object
|
||||
|
||||
|
||||
class User(DataBaseModel, UserMixin):
|
||||
class User(DataBaseModel, AuthUser):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
access_token = CharField(max_length=255, null=True, index=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
|
||||
@ -748,7 +749,7 @@ class Knowledgebase(DataBaseModel):
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
@ -772,8 +773,8 @@ class Document(DataBaseModel):
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
|
||||
@ -876,7 +877,7 @@ class Dialog(DataBaseModel):
|
||||
class Conversation(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="conversation name", index=True)
|
||||
message = JSONField(null=True)
|
||||
reference = JSONField(null=True, default=[])
|
||||
user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)
|
||||
@ -1112,6 +1113,91 @@ class SyncLogs(DataBaseModel):
|
||||
db_table = "sync_logs"
|
||||
|
||||
|
||||
class EvaluationDataset(DataBaseModel):
|
||||
"""Ground truth dataset for RAG evaluation"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
|
||||
name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
|
||||
description = TextField(null=True, help_text="dataset description")
|
||||
kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
|
||||
created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
|
||||
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||
update_time = BigIntegerField(null=False, help_text="last update timestamp")
|
||||
status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_datasets"
|
||||
|
||||
|
||||
class EvaluationCase(DataBaseModel):
|
||||
"""Individual test case in an evaluation dataset"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||
question = TextField(null=False, help_text="test question")
|
||||
reference_answer = TextField(null=True, help_text="optional ground truth answer")
|
||||
relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
|
||||
relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
|
||||
metadata = JSONField(null=True, help_text="additional context/tags")
|
||||
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_cases"
|
||||
|
||||
|
||||
class EvaluationRun(DataBaseModel):
|
||||
"""A single evaluation run"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||
dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
|
||||
name = CharField(max_length=255, null=False, help_text="run name")
|
||||
config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
|
||||
metrics_summary = JSONField(null=True, help_text="aggregated metrics")
|
||||
status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
|
||||
created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
|
||||
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||
complete_time = BigIntegerField(null=True, help_text="completion timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_runs"
|
||||
|
||||
|
||||
class EvaluationResult(DataBaseModel):
|
||||
"""Result for a single test case in an evaluation run"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
|
||||
case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
|
||||
generated_answer = TextField(null=False, help_text="generated answer")
|
||||
retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
|
||||
metrics = JSONField(null=False, help_text="all computed metrics")
|
||||
execution_time = FloatField(null=False, help_text="response time in seconds")
|
||||
token_usage = JSONField(null=True, help_text="prompt/completion tokens")
|
||||
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_results"
|
||||
|
||||
|
||||
class Memory(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
name = CharField(max_length=128, null=False, index=False, help_text="Memory name")
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
memory_type = IntegerField(null=False, default=1, index=True, help_text="Bit flags (LSB->MSB): 1=raw, 2=semantic, 4=episodic, 8=procedural. E.g., 5 enables raw + episodic.")
|
||||
storage_type = CharField(max_length=32, default='table', null=False, index=True, help_text="table|graph")
|
||||
embd_id = CharField(max_length=128, null=False, index=False, help_text="embedding model ID")
|
||||
llm_id = CharField(max_length=128, null=False, index=False, help_text="chat model ID")
|
||||
permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me")
|
||||
description = TextField(null=True, help_text="description")
|
||||
memory_size = IntegerField(default=5242880, null=False, index=False)
|
||||
forgetting_policy = CharField(max_length=32, null=False, default="fifo", index=False, help_text="lru|fifo")
|
||||
temperature = FloatField(default=0.5, index=False)
|
||||
system_prompt = TextField(null=True, help_text="system prompt", index=False)
|
||||
user_prompt = TextField(null=True, help_text="user prompt", index=False)
|
||||
|
||||
class Meta:
|
||||
db_table = "memory"
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1292,4 +1378,43 @@ def migrate_db():
|
||||
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# RAG Evaluation tables
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
@ -34,14 +35,17 @@ from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
DEFAULT_SUPERUSER_NICKNAME = os.getenv("DEFAULT_SUPERUSER_NICKNAME", "admin")
|
||||
DEFAULT_SUPERUSER_EMAIL = os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")
|
||||
DEFAULT_SUPERUSER_PASSWORD = os.getenv("DEFAULT_SUPERUSER_PASSWORD", "admin")
|
||||
|
||||
def init_superuser():
|
||||
def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_EMAIL, password=DEFAULT_SUPERUSER_PASSWORD, role=UserTenantRole.OWNER):
|
||||
user_info = {
|
||||
"id": uuid.uuid1().hex,
|
||||
"password": encode_to_base64("admin"),
|
||||
"nickname": "admin",
|
||||
"password": encode_to_base64(password),
|
||||
"nickname": nickname,
|
||||
"is_superuser": True,
|
||||
"email": "admin@ragflow.io",
|
||||
"email": email,
|
||||
"creator": "system",
|
||||
"status": "1",
|
||||
}
|
||||
@ -58,7 +62,7 @@ def init_superuser():
|
||||
"tenant_id": user_info["id"],
|
||||
"user_id": user_info["id"],
|
||||
"invited_by": user_info["id"],
|
||||
"role": UserTenantRole.OWNER
|
||||
"role": role
|
||||
}
|
||||
|
||||
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||
@ -70,11 +74,10 @@ def init_superuser():
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
logging.info(
|
||||
"Super user initialized. email: admin@ragflow.io, password: admin. Changing the password after login is strongly recommended.")
|
||||
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")
|
||||
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = chat_mdl.chat(system="", history=[
|
||||
{"role": "user", "content": "Hello!"}], gen_conf={})
|
||||
msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
|
||||
if msg.find("ERROR: ") == 0:
|
||||
logging.error(
|
||||
"'{}' doesn't work. {}".format(
|
||||
|
||||
@ -153,7 +153,7 @@ def delete_user_data(user_id: str) -> dict:
|
||||
done_msg += "Start to delete owned tenant.\n"
|
||||
tenant_id = owned_tenant[0]["tenant_id"]
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
|
||||
# step1.1 delete knowledgebase related file and info
|
||||
# step1.1 delete dataset related file and info
|
||||
if kb_ids:
|
||||
# step1.1.1 delete files in storage, remove bucket
|
||||
for kb_id in kb_ids:
|
||||
@ -182,7 +182,7 @@ def delete_user_data(user_id: str) -> dict:
|
||||
search.index_name(tenant_id), kb_ids)
|
||||
done_msg += f"- Deleted {r} chunk records.\n"
|
||||
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
|
||||
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
|
||||
done_msg += f"- Deleted {kb_delete_res} dataset records.\n"
|
||||
# step1.1.4 delete agents
|
||||
agent_delete_res = delete_user_agents(usr.id)
|
||||
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
|
||||
@ -258,7 +258,7 @@ def delete_user_data(user_id: str) -> dict:
|
||||
# step2.1.5 delete document record
|
||||
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {doc_delete_res} documents.\n"
|
||||
# step2.1.6 update knowledge base doc&chunk&token cnt
|
||||
# step2.1.6 update dataset doc&chunk&token cnt
|
||||
for kb_id, doc_num in kb_doc_info.items():
|
||||
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
|
||||
|
||||
@ -273,7 +273,7 @@ def delete_user_data(user_id: str) -> dict:
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
|
||||
return {"success": False, "message": "An internal error occurred during user deletion. Some operations may have completed.","details": done_msg}
|
||||
|
||||
|
||||
def delete_user_agents(user_id: str) -> dict:
|
||||
|
||||
@ -177,7 +177,7 @@ class UserCanvasService(CommonService):
|
||||
return True
|
||||
|
||||
|
||||
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
query = kwargs.get("query", "") or kwargs.get("question", "")
|
||||
files = kwargs.get("files", [])
|
||||
inputs = kwargs.get("inputs", {})
|
||||
@ -219,10 +219,14 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
"id": message_id
|
||||
})
|
||||
txt = ""
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
ans["session_id"] = session_id
|
||||
if ans["event"] == "message":
|
||||
txt += ans["data"]["content"]
|
||||
if ans["data"].get("start_to_think", False):
|
||||
txt += "<think>"
|
||||
elif ans["data"].get("end_to_think", False):
|
||||
txt += "</think>"
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
|
||||
@ -233,7 +237,7 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
API4ConversationService.append_message(conv["id"], conv)
|
||||
|
||||
|
||||
def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
|
||||
prompt_tokens = len(tiktoken_encoder.encode(str(question)))
|
||||
user_id = kwargs.get("user_id", "")
|
||||
@ -241,7 +245,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru
|
||||
if stream:
|
||||
completion_tokens = 0
|
||||
try:
|
||||
for ans in completion(
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
@ -300,7 +304,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru
|
||||
try:
|
||||
all_content = ""
|
||||
reference = {}
|
||||
for ans in completion(
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
|
||||
@ -169,10 +169,12 @@ class CommonService:
|
||||
"""
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = get_uuid()
|
||||
kwargs["create_time"] = current_timestamp()
|
||||
kwargs["create_date"] = datetime_format(datetime.now())
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
kwargs["update_date"] = datetime_format(datetime.now())
|
||||
timestamp = current_timestamp()
|
||||
cur_datetime = datetime_format(datetime.now())
|
||||
kwargs["create_time"] = timestamp
|
||||
kwargs["create_date"] = cur_datetime
|
||||
kwargs["update_time"] = timestamp
|
||||
kwargs["update_date"] = cur_datetime
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
|
||||
@ -207,10 +209,14 @@ class CommonService:
|
||||
data_list (list): List of dictionaries containing record data to update.
|
||||
Each dictionary must include an 'id' field.
|
||||
"""
|
||||
|
||||
timestamp = current_timestamp()
|
||||
cur_datetime = datetime_format(datetime.now())
|
||||
for data in data_list:
|
||||
data["update_time"] = timestamp
|
||||
data["update_date"] = cur_datetime
|
||||
with DB.atomic():
|
||||
for data in data_list:
|
||||
data["update_time"] = current_timestamp()
|
||||
data["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(data).where(cls.model.id == data["id"]).execute()
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import os
|
||||
from typing import Tuple, List
|
||||
|
||||
from anthropic import BaseModel
|
||||
@ -24,7 +25,6 @@ from api.db import InputType
|
||||
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
@ -68,9 +68,10 @@ class ConnectorService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
|
||||
from api.db.services.file_service import FileService
|
||||
e, conn = cls.get_by_id(connector_id)
|
||||
if not e:
|
||||
return
|
||||
return None
|
||||
SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id)
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
@ -103,7 +104,8 @@ class SyncLogsService(CommonService):
|
||||
Knowledgebase.avatar.alias("kb_avatar"),
|
||||
Connector2Kb.auto_parse,
|
||||
cls.model.from_beginning.alias("reindex"),
|
||||
cls.model.status
|
||||
cls.model.status,
|
||||
cls.model.update_time
|
||||
]
|
||||
if not connector_id:
|
||||
fields.append(Connector.config)
|
||||
@ -116,7 +118,11 @@ class SyncLogsService(CommonService):
|
||||
if connector_id:
|
||||
query = query.where(cls.model.connector_id == connector_id)
|
||||
else:
|
||||
interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
|
||||
database_type = os.getenv("DB_TYPE", "mysql")
|
||||
if "postgres" in database_type.lower():
|
||||
interval_expr = SQL("make_interval(mins => t2.refresh_freq)")
|
||||
else:
|
||||
interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
|
||||
query = query.where(
|
||||
Connector.input_type == InputType.POLL,
|
||||
Connector.status == TaskStatus.SCHEDULE,
|
||||
@ -125,11 +131,11 @@ class SyncLogsService(CommonService):
|
||||
)
|
||||
|
||||
query = query.distinct().order_by(cls.model.update_time.desc())
|
||||
totbal = query.count()
|
||||
total = query.count()
|
||||
if page_number:
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), totbal
|
||||
return list(query.dicts()), total
|
||||
|
||||
@classmethod
|
||||
def start(cls, id, connector_id):
|
||||
@ -191,6 +197,7 @@ class SyncLogsService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
|
||||
from api.db.services.file_service import FileService
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
@ -207,9 +214,21 @@ class SyncLogsService(CommonService):
|
||||
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
|
||||
errs.extend(err)
|
||||
|
||||
# Create a mapping from filename to metadata for later use
|
||||
metadata_map = {}
|
||||
for d in docs:
|
||||
if d.get("metadata"):
|
||||
filename = d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else "")
|
||||
metadata_map[filename] = d["metadata"]
|
||||
|
||||
kb_table_num_map = {}
|
||||
for doc, _ in doc_blob_pairs:
|
||||
doc_ids.append(doc["id"])
|
||||
|
||||
# Set metadata if available for this document
|
||||
if doc["name"] in metadata_map:
|
||||
DocumentService.update_by_id(doc["id"], {"meta_fields": metadata_map[doc["name"]]})
|
||||
|
||||
if not auto_parse or auto_parse == "0":
|
||||
continue
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
@ -242,7 +261,7 @@ class Connector2KbService(CommonService):
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id,
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ from common.constants import StatusEnum
|
||||
from api.db.db_models import Conversation, DB
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from api.db.services.dialog_service import DialogService, async_chat
|
||||
from common.misc_utils import get_uuid
|
||||
import json
|
||||
|
||||
@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||
conv.reference[-1] = reference
|
||||
return ans
|
||||
|
||||
|
||||
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||
assert name, "`name` can not be empty."
|
||||
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
assert dia, "You do not own the chat."
|
||||
@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
||||
"reference": {},
|
||||
"audio_binary": None,
|
||||
"id": None,
|
||||
"session_id": session_id
|
||||
"session_id": session_id
|
||||
}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
||||
|
||||
if stream:
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **kwargs):
|
||||
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
||||
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, False, **kwargs):
|
||||
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||
answer = structure_answer(conv, ans, message_id, session_id)
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
break
|
||||
yield answer
|
||||
|
||||
|
||||
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
|
||||
async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
assert e, "Dialog not found"
|
||||
if not session_id:
|
||||
@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
||||
|
||||
if stream:
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **kwargs):
|
||||
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
||||
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, False, **kwargs):
|
||||
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||
answer = structure_answer(conv, ans, message_id, session_id)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
|
||||
@ -21,10 +21,10 @@ from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
from langfuse import Langfuse
|
||||
from peewee import fn
|
||||
from agentic_reasoning import DeepResearcher
|
||||
from api.db.services.file_service import FileService
|
||||
from common.constants import LLMType, ParserType, StatusEnum
|
||||
from api.db.db_models import DB, Dialog
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -32,6 +32,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
@ -39,7 +40,7 @@ from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
@ -177,7 +178,11 @@ class DialogService(CommonService):
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
|
||||
async def async_chat_solo(dialog, messages, stream=True):
|
||||
attachments = ""
|
||||
if "files" in messages[-1]:
|
||||
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
@ -188,10 +193,13 @@ def chat_solo(dialog, messages, stream=True):
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if stream:
|
||||
last_ans = ""
|
||||
delta_ans = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
@ -202,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
||||
@ -270,77 +278,10 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
||||
return answer, idx
|
||||
|
||||
|
||||
def convert_conditions(metadata_condition):
|
||||
if metadata_condition is None:
|
||||
metadata_condition = {}
|
||||
op_mapping = {
|
||||
"is": "=",
|
||||
"not is": "≠"
|
||||
}
|
||||
return [
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
|
||||
|
||||
def meta_filter(metas: dict, filters: list[dict]):
|
||||
doc_ids = set([])
|
||||
|
||||
def filter_out(v2docs, operator, value):
|
||||
ids = []
|
||||
for input, docids in v2docs.items():
|
||||
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
|
||||
try:
|
||||
input = float(input)
|
||||
value = float(value)
|
||||
except Exception:
|
||||
input = str(input)
|
||||
value = str(value)
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
try:
|
||||
if all(conds):
|
||||
ids.extend(docids)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
return ids
|
||||
|
||||
for k, v2docs in metas.items():
|
||||
for f in filters:
|
||||
if k != f["key"]:
|
||||
continue
|
||||
ids = filter_out(v2docs, f["op"], f["value"])
|
||||
if not doc_ids:
|
||||
doc_ids = set(ids)
|
||||
else:
|
||||
doc_ids = doc_ids & set(ids)
|
||||
if not doc_ids:
|
||||
return []
|
||||
return list(doc_ids)
|
||||
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||
for ans in chat_solo(dialog, messages, stream):
|
||||
async for ans in async_chat_solo(dialog, messages, stream):
|
||||
yield ans
|
||||
return
|
||||
|
||||
@ -375,15 +316,18 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
retriever = settings.retriever
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
attachments_= ""
|
||||
if "doc_ids" in messages[-1]:
|
||||
attachments = messages[-1]["doc_ids"]
|
||||
if "files" in messages[-1]:
|
||||
attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
@ -397,27 +341,25 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||
|
||||
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||||
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
else:
|
||||
questions = questions[-1:]
|
||||
|
||||
if prompt_config.get("cross_languages"):
|
||||
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||||
questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||||
|
||||
if dialog.meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||
if dialog.meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
elif dialog.meta_data_filter.get("method") == "manual":
|
||||
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
attachments = await apply_meta_data_filter(
|
||||
dialog.meta_data_filter,
|
||||
metas,
|
||||
questions[-1],
|
||||
chat_mdl,
|
||||
attachments,
|
||||
)
|
||||
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
questions[-1] += await keyword_extraction(chat_mdl, questions[-1])
|
||||
|
||||
refine_question_ts = timer()
|
||||
|
||||
@ -445,7 +387,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
),
|
||||
)
|
||||
|
||||
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||
async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
@ -472,6 +414,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
|
||||
if prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
@ -492,12 +435,13 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||
"audio_binary": tts(tts_mdl, empty_res)}
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
return
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
gen_conf = dialog.llm_setting
|
||||
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
|
||||
prompt4citation = ""
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
prompt4citation = citation_prompt()
|
||||
@ -596,7 +540,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if stream:
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||
if thought:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
@ -610,17 +554,19 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
res = decorate_answer(answer)
|
||||
res["audio_binary"] = tts(tts_mdl, answer)
|
||||
yield res
|
||||
|
||||
return
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
|
||||
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = """
|
||||
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
|
||||
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
|
||||
Ensure that:
|
||||
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
|
||||
2. Write only the SQL, no explanations or additional text.
|
||||
@ -636,9 +582,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
|
||||
tried_times = 0
|
||||
|
||||
def get_table():
|
||||
async def get_table():
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
|
||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
@ -664,7 +610,11 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
if kb_ids:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
if "where" not in sql.lower():
|
||||
sql += f" WHERE {kb_filter}"
|
||||
o = sql.lower().split("order by")
|
||||
if len(o) > 1:
|
||||
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
|
||||
else:
|
||||
sql += f" WHERE {kb_filter}"
|
||||
else:
|
||||
sql += f" AND {kb_filter}"
|
||||
|
||||
@ -672,10 +622,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
tried_times += 1
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
tbl, sql = get_table()
|
||||
if tbl is None:
|
||||
return None
|
||||
if tbl.get("error") and tried_times <= 2:
|
||||
try:
|
||||
tbl, sql = await get_table()
|
||||
except Exception as e:
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
@ -689,16 +638,14 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Error issued by database as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
|
||||
tbl, sql = get_table()
|
||||
logging.debug("TRY it again: {}".format(sql))
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
|
||||
try:
|
||||
tbl, sql = await get_table()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
logging.debug("GET table: {}".format(tbl))
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||
if len(tbl["rows"]) == 0:
|
||||
return None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
@ -742,17 +689,51 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
"prompt": sys_prompt,
|
||||
}
|
||||
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||
|
||||
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||
|
||||
emoji_pattern = re.compile(
|
||||
"[\U0001F600-\U0001F64F"
|
||||
"\U0001F300-\U0001F5FF"
|
||||
"\U0001F680-\U0001F6FF"
|
||||
"\U0001F1E0-\U0001F1FF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U0001F900-\U0001F9FF"
|
||||
"\U0001FA70-\U0001FAFF"
|
||||
"\U0001FAD0-\U0001FAFF]+",
|
||||
flags=re.UNICODE
|
||||
)
|
||||
text = emoji_pattern.sub("", text)
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
MAX_LEN = 500
|
||||
if len(text) > MAX_LEN:
|
||||
text = text[:MAX_LEN]
|
||||
|
||||
return text
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text:
|
||||
return
|
||||
return None
|
||||
text = clean_tts_text(text)
|
||||
if not text:
|
||||
return None
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
except Exception as e:
|
||||
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
|
||||
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_mdl = None
|
||||
kb_ids = search_config.get("kb_ids", kb_ids)
|
||||
@ -775,15 +756,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
question=question,
|
||||
@ -826,13 +799,13 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
return {"answer": answer, "reference": refs}
|
||||
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
|
||||
|
||||
def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_id = search_config.get("rerank_id", "")
|
||||
@ -850,15 +823,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
question=question,
|
||||
@ -876,5 +841,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
rank_feature=label_question(question, kbs),
|
||||
)
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||
mind_map = await mindmap([c["content_with_weight"] for c in ranks["chunks"]])
|
||||
return mind_map.output
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
@ -22,7 +23,6 @@ from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
import trio
|
||||
import xxhash
|
||||
from peewee import fn, Case, JOIN
|
||||
|
||||
@ -41,6 +41,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common import settings
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
|
||||
@ -78,7 +79,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name, suffix=None, run = None):
|
||||
orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None):
|
||||
fields = cls.get_cls_model_fields()
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
||||
.join(File, on = (File.id == File2Document.file_id))\
|
||||
@ -95,6 +96,8 @@ class DocumentService(CommonService):
|
||||
docs = docs.where(
|
||||
fn.LOWER(cls.model.name).contains(keywords.lower())
|
||||
)
|
||||
if doc_ids:
|
||||
docs = docs.where(cls.model.id.in_(doc_ids))
|
||||
if suffix:
|
||||
docs = docs.where(cls.model.suffix.in_(suffix))
|
||||
if run:
|
||||
@ -113,7 +116,7 @@ class DocumentService(CommonService):
|
||||
def check_doc_health(cls, tenant_id: str, filename):
|
||||
import os
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
|
||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
|
||||
raise RuntimeError("Exceed the maximum file number of a free user!")
|
||||
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
raise RuntimeError("Exceed the maximum length of file name!")
|
||||
@ -122,7 +125,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, run_status, types, suffix):
|
||||
orderby, desc, keywords, run_status, types, suffix, doc_ids=None):
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||
@ -142,6 +145,8 @@ class DocumentService(CommonService):
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
|
||||
if doc_ids:
|
||||
docs = docs.where(cls.model.id.in_(doc_ids))
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
@ -175,6 +180,16 @@ class DocumentService(CommonService):
|
||||
"1": 2,
|
||||
"2": 2
|
||||
}
|
||||
"metadata": {
|
||||
"key1": {
|
||||
"key1_value1": 1,
|
||||
"key1_value2": 2,
|
||||
},
|
||||
"key2": {
|
||||
"key2_value1": 2,
|
||||
"key2_value2": 1,
|
||||
},
|
||||
}
|
||||
}, total
|
||||
where "1" => RUNNING, "2" => CANCEL
|
||||
"""
|
||||
@ -195,19 +210,40 @@ class DocumentService(CommonService):
|
||||
if suffix:
|
||||
query = query.where(cls.model.suffix.in_(suffix))
|
||||
|
||||
rows = query.select(cls.model.run, cls.model.suffix)
|
||||
rows = query.select(cls.model.run, cls.model.suffix, cls.model.meta_fields)
|
||||
total = rows.count()
|
||||
|
||||
suffix_counter = {}
|
||||
run_status_counter = {}
|
||||
metadata_counter = {}
|
||||
|
||||
for row in rows:
|
||||
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
|
||||
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
|
||||
meta_fields = row.meta_fields or {}
|
||||
if isinstance(meta_fields, str):
|
||||
try:
|
||||
meta_fields = json.loads(meta_fields)
|
||||
except Exception:
|
||||
meta_fields = {}
|
||||
if not isinstance(meta_fields, dict):
|
||||
continue
|
||||
for key, value in meta_fields.items():
|
||||
values = value if isinstance(value, list) else [value]
|
||||
for vv in values:
|
||||
if vv is None:
|
||||
continue
|
||||
if isinstance(vv, str) and not vv.strip():
|
||||
continue
|
||||
sv = str(vv)
|
||||
if key not in metadata_counter:
|
||||
metadata_counter[key] = {}
|
||||
metadata_counter[key][sv] = metadata_counter[key].get(sv, 0) + 1
|
||||
|
||||
return {
|
||||
"suffix": suffix_counter,
|
||||
"run_status": run_status_counter
|
||||
"run_status": run_status_counter,
|
||||
"metadata": metadata_counter,
|
||||
}, total
|
||||
|
||||
@classmethod
|
||||
@ -309,7 +345,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -322,7 +358,7 @@ class DocumentService(CommonService):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
graph_source = settings.docStoreConn.getFields(
|
||||
graph_source = settings.docStoreConn.get_fields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
@ -464,7 +500,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -473,7 +509,7 @@ class DocumentService(CommonService):
|
||||
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["kb_id"]
|
||||
|
||||
@classmethod
|
||||
@ -486,7 +522,7 @@ class DocumentService(CommonService):
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -533,7 +569,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["embd_id"]
|
||||
|
||||
@classmethod
|
||||
@ -569,7 +605,7 @@ class DocumentService(CommonService):
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return
|
||||
return None
|
||||
return doc_id[0]["id"]
|
||||
|
||||
@classmethod
|
||||
@ -643,6 +679,13 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_meta_by_kbs(cls, kb_ids):
|
||||
"""
|
||||
Legacy metadata aggregator (backward-compatible).
|
||||
- Does NOT expand list values and a list is kept as one string key.
|
||||
Example: {"tags": ["foo","bar"]} -> meta["tags"]["['foo', 'bar']"] = [doc_id]
|
||||
- Expects meta_fields is a dict.
|
||||
Use when existing callers rely on the old list-as-string semantics.
|
||||
"""
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.meta_fields,
|
||||
@ -659,6 +702,171 @@ class DocumentService(CommonService):
|
||||
meta[k][v].append(doc_id)
|
||||
return meta
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_flatted_meta_by_kbs(cls, kb_ids):
|
||||
"""
|
||||
- Parses stringified JSON meta_fields when possible and skips non-dict or unparsable values.
|
||||
- Expands list values into individual entries.
|
||||
Example: {"tags": ["foo","bar"], "author": "alice"} ->
|
||||
meta["tags"]["foo"] = [doc_id], meta["tags"]["bar"] = [doc_id], meta["author"]["alice"] = [doc_id]
|
||||
Prefer for metadata_condition filtering and scenarios that must respect list semantics.
|
||||
"""
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.meta_fields,
|
||||
]
|
||||
meta = {}
|
||||
for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
|
||||
doc_id = r.id
|
||||
meta_fields = r.meta_fields or {}
|
||||
if isinstance(meta_fields, str):
|
||||
try:
|
||||
meta_fields = json.loads(meta_fields)
|
||||
except Exception:
|
||||
continue
|
||||
if not isinstance(meta_fields, dict):
|
||||
continue
|
||||
for k, v in meta_fields.items():
|
||||
if k not in meta:
|
||||
meta[k] = {}
|
||||
values = v if isinstance(v, list) else [v]
|
||||
for vv in values:
|
||||
if vv is None:
|
||||
continue
|
||||
sv = str(vv)
|
||||
if sv not in meta[k]:
|
||||
meta[k][sv] = []
|
||||
meta[k][sv].append(doc_id)
|
||||
return meta
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_metadata_summary(cls, kb_id):
|
||||
fields = [cls.model.id, cls.model.meta_fields]
|
||||
summary = {}
|
||||
for r in cls.model.select(*fields).where(cls.model.kb_id == kb_id):
|
||||
meta_fields = r.meta_fields or {}
|
||||
if isinstance(meta_fields, str):
|
||||
try:
|
||||
meta_fields = json.loads(meta_fields)
|
||||
except Exception:
|
||||
continue
|
||||
if not isinstance(meta_fields, dict):
|
||||
continue
|
||||
for k, v in meta_fields.items():
|
||||
values = v if isinstance(v, list) else [v]
|
||||
for vv in values:
|
||||
if not vv:
|
||||
continue
|
||||
sv = str(vv)
|
||||
if k not in summary:
|
||||
summary[k] = {}
|
||||
summary[k][sv] = summary[k].get(sv, 0) + 1
|
||||
return {k: sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True) for k, v in summary.items()}
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None):
|
||||
updates = updates or []
|
||||
deletes = deletes or []
|
||||
if not doc_ids:
|
||||
return 0
|
||||
|
||||
def _normalize_meta(meta):
|
||||
if isinstance(meta, str):
|
||||
try:
|
||||
meta = json.loads(meta)
|
||||
except Exception:
|
||||
return {}
|
||||
if not isinstance(meta, dict):
|
||||
return {}
|
||||
return deepcopy(meta)
|
||||
|
||||
def _str_equal(a, b):
|
||||
return str(a) == str(b)
|
||||
|
||||
def _apply_updates(meta):
|
||||
changed = False
|
||||
for upd in updates:
|
||||
key = upd.get("key")
|
||||
if not key or key not in meta:
|
||||
continue
|
||||
|
||||
new_value = upd.get("value")
|
||||
match_provided = "match" in upd
|
||||
if isinstance(meta[key], list):
|
||||
if not match_provided:
|
||||
meta[key] = new_value
|
||||
changed = True
|
||||
else:
|
||||
match_value = upd.get("match")
|
||||
replaced = False
|
||||
new_list = []
|
||||
for item in meta[key]:
|
||||
if _str_equal(item, match_value):
|
||||
new_list.append(new_value)
|
||||
replaced = True
|
||||
else:
|
||||
new_list.append(item)
|
||||
if replaced:
|
||||
meta[key] = new_list
|
||||
changed = True
|
||||
else:
|
||||
if not match_provided:
|
||||
meta[key] = new_value
|
||||
changed = True
|
||||
else:
|
||||
match_value = upd.get("match")
|
||||
if _str_equal(meta[key], match_value):
|
||||
meta[key] = new_value
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
def _apply_deletes(meta):
|
||||
changed = False
|
||||
for d in deletes:
|
||||
key = d.get("key")
|
||||
if not key or key not in meta:
|
||||
continue
|
||||
value = d.get("value", None)
|
||||
if isinstance(meta[key], list):
|
||||
if value is None:
|
||||
del meta[key]
|
||||
changed = True
|
||||
continue
|
||||
new_list = [item for item in meta[key] if not _str_equal(item, value)]
|
||||
if len(new_list) != len(meta[key]):
|
||||
if new_list:
|
||||
meta[key] = new_list
|
||||
else:
|
||||
del meta[key]
|
||||
changed = True
|
||||
else:
|
||||
if value is None or _str_equal(meta[key], value):
|
||||
del meta[key]
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
updated_docs = 0
|
||||
with DB.atomic():
|
||||
rows = cls.model.select(cls.model.id, cls.model.meta_fields).where(
|
||||
(cls.model.id.in_(doc_ids)) & (cls.model.kb_id == kb_id)
|
||||
)
|
||||
for r in rows:
|
||||
meta = _normalize_meta(r.meta_fields or {})
|
||||
original_meta = deepcopy(meta)
|
||||
changed = _apply_updates(meta)
|
||||
changed = _apply_deletes(meta) or changed
|
||||
if changed and meta != original_meta:
|
||||
cls.model.update(
|
||||
meta_fields=meta,
|
||||
update_time=current_timestamp(),
|
||||
update_date=get_format_time()
|
||||
).where(cls.model.id == r.id).execute()
|
||||
updated_docs += 1
|
||||
return updated_docs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress(cls):
|
||||
@ -715,13 +923,17 @@ class DocumentService(CommonService):
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
# only for special task and parsed docs and unfinised
|
||||
# only for special task and parsed docs and unfinished
|
||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||
msg = "\n".join(sorted(msg))
|
||||
begin_at = d.get("process_begin_at")
|
||||
if not begin_at:
|
||||
begin_at = datetime.now()
|
||||
# fallback
|
||||
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
|
||||
|
||||
info = {
|
||||
"process_duration": datetime.timestamp(
|
||||
datetime.now()) -
|
||||
d["process_begin_at"].timestamp(),
|
||||
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
|
||||
"run": status}
|
||||
if prg != 0 and not freeze_progress:
|
||||
info["progress"] = prg
|
||||
@ -901,12 +1113,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not dia.kb_ids:
|
||||
raise LookupError("No knowledge base associated with this conversation. "
|
||||
"Please add a knowledge base before uploading documents")
|
||||
raise LookupError("No dataset associated with this conversation. "
|
||||
"Please add a dataset before uploading documents")
|
||||
kb_id = dia.kb_ids[0]
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
raise LookupError("Can't find this dataset!")
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||
|
||||
@ -922,7 +1134,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
doc_nm = {}
|
||||
@ -974,13 +1186,13 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
def embedding(doc_id, cnts, batch_size=16):
|
||||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vects = []
|
||||
vectors = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
vectors.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vects
|
||||
return vectors
|
||||
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
try_create_idx = True
|
||||
@ -994,7 +1206,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
||||
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
@ -1011,15 +1223,15 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
except Exception:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vects)
|
||||
vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vectors)
|
||||
for i, d in enumerate(cks):
|
||||
v = vects[i]
|
||||
v = vectors[i]
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
637
api/db/services/evaluation_service.py
Normal file
637
api/db/services/evaluation_service.py
Normal file
@ -0,0 +1,637 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
RAG Evaluation Service
|
||||
|
||||
Provides functionality for evaluating RAG system performance including:
|
||||
- Dataset management
|
||||
- Test case management
|
||||
- Evaluation execution
|
||||
- Metrics computation
|
||||
- Configuration recommendations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp
|
||||
from common.constants import StatusEnum
|
||||
|
||||
|
||||
class EvaluationService(CommonService):
|
||||
"""Service for managing RAG evaluations"""
|
||||
|
||||
model = EvaluationDataset
|
||||
|
||||
# ==================== Dataset Management ====================
|
||||
|
||||
@classmethod
|
||||
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
|
||||
tenant_id: str, user_id: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Create a new evaluation dataset.
|
||||
|
||||
Args:
|
||||
name: Dataset name
|
||||
description: Dataset description
|
||||
kb_ids: List of knowledge base IDs to evaluate against
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID who creates the dataset
|
||||
|
||||
Returns:
|
||||
(success, dataset_id or error_message)
|
||||
"""
|
||||
try:
|
||||
dataset_id = get_uuid()
|
||||
dataset = {
|
||||
"id": dataset_id,
|
||||
"tenant_id": tenant_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"kb_ids": kb_ids,
|
||||
"created_by": user_id,
|
||||
"create_time": current_timestamp(),
|
||||
"update_time": current_timestamp(),
|
||||
"status": StatusEnum.VALID.value
|
||||
}
|
||||
|
||||
if not EvaluationDataset.create(**dataset):
|
||||
return False, "Failed to create dataset"
|
||||
|
||||
return True, dataset_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error creating evaluation dataset: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get dataset by ID"""
|
||||
try:
|
||||
dataset = EvaluationDataset.get_by_id(dataset_id)
|
||||
if dataset:
|
||||
return dataset.to_dict()
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting dataset {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def list_datasets(cls, tenant_id: str, user_id: str,
|
||||
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
|
||||
"""List datasets for a tenant"""
|
||||
try:
|
||||
query = EvaluationDataset.select().where(
|
||||
(EvaluationDataset.tenant_id == tenant_id) &
|
||||
(EvaluationDataset.status == StatusEnum.VALID.value)
|
||||
).order_by(EvaluationDataset.create_time.desc())
|
||||
|
||||
total = query.count()
|
||||
datasets = query.paginate(page, page_size)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"datasets": [d.to_dict() for d in datasets]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error listing datasets: {e}")
|
||||
return {"total": 0, "datasets": []}
|
||||
|
||||
@classmethod
|
||||
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
|
||||
"""Update dataset"""
|
||||
try:
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
return EvaluationDataset.update(**kwargs).where(
|
||||
EvaluationDataset.id == dataset_id
|
||||
).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating dataset {dataset_id}: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def delete_dataset(cls, dataset_id: str) -> bool:
|
||||
"""Soft delete dataset"""
|
||||
try:
|
||||
return EvaluationDataset.update(
|
||||
status=StatusEnum.INVALID.value,
|
||||
update_time=current_timestamp()
|
||||
).where(EvaluationDataset.id == dataset_id).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting dataset {dataset_id}: {e}")
|
||||
return False
|
||||
|
||||
# ==================== Test Case Management ====================
|
||||
|
||||
@classmethod
|
||||
def add_test_case(cls, dataset_id: str, question: str,
|
||||
reference_answer: Optional[str] = None,
|
||||
relevant_doc_ids: Optional[List[str]] = None,
|
||||
relevant_chunk_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Add a test case to a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
question: Test question
|
||||
reference_answer: Optional ground truth answer
|
||||
relevant_doc_ids: Optional list of relevant document IDs
|
||||
relevant_chunk_ids: Optional list of relevant chunk IDs
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
(success, case_id or error_message)
|
||||
"""
|
||||
try:
|
||||
case_id = get_uuid()
|
||||
case = {
|
||||
"id": case_id,
|
||||
"dataset_id": dataset_id,
|
||||
"question": question,
|
||||
"reference_answer": reference_answer,
|
||||
"relevant_doc_ids": relevant_doc_ids,
|
||||
"relevant_chunk_ids": relevant_chunk_ids,
|
||||
"metadata": metadata,
|
||||
"create_time": current_timestamp()
|
||||
}
|
||||
|
||||
if not EvaluationCase.create(**case):
|
||||
return False, "Failed to create test case"
|
||||
|
||||
return True, case_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error adding test case: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all test cases for a dataset"""
|
||||
try:
|
||||
cases = EvaluationCase.select().where(
|
||||
EvaluationCase.dataset_id == dataset_id
|
||||
).order_by(EvaluationCase.create_time)
|
||||
|
||||
return [c.to_dict() for c in cases]
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def delete_test_case(cls, case_id: str) -> bool:
|
||||
"""Delete a test case"""
|
||||
try:
|
||||
return EvaluationCase.delete().where(
|
||||
EvaluationCase.id == case_id
|
||||
).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting test case {case_id}: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
|
||||
"""
|
||||
Bulk import test cases from a list.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
cases: List of test case dictionaries
|
||||
|
||||
Returns:
|
||||
(success_count, failure_count)
|
||||
"""
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for case_data in cases:
|
||||
success, _ = cls.add_test_case(
|
||||
dataset_id=dataset_id,
|
||||
question=case_data.get("question", ""),
|
||||
reference_answer=case_data.get("reference_answer"),
|
||||
relevant_doc_ids=case_data.get("relevant_doc_ids"),
|
||||
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
|
||||
metadata=case_data.get("metadata")
|
||||
)
|
||||
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
return success_count, failure_count
|
||||
|
||||
# ==================== Evaluation Execution ====================
|
||||
|
||||
@classmethod
|
||||
def start_evaluation(cls, dataset_id: str, dialog_id: str,
|
||||
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
dialog_id: Dialog configuration to evaluate
|
||||
user_id: User ID who starts the run
|
||||
name: Optional run name
|
||||
|
||||
Returns:
|
||||
(success, run_id or error_message)
|
||||
"""
|
||||
try:
|
||||
# Get dialog configuration
|
||||
success, dialog = DialogService.get_by_id(dialog_id)
|
||||
if not success:
|
||||
return False, "Dialog not found"
|
||||
|
||||
# Create evaluation run
|
||||
run_id = get_uuid()
|
||||
if not name:
|
||||
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
run = {
|
||||
"id": run_id,
|
||||
"dataset_id": dataset_id,
|
||||
"dialog_id": dialog_id,
|
||||
"name": name,
|
||||
"config_snapshot": dialog.to_dict(),
|
||||
"metrics_summary": None,
|
||||
"status": "RUNNING",
|
||||
"created_by": user_id,
|
||||
"create_time": current_timestamp(),
|
||||
"complete_time": None
|
||||
}
|
||||
|
||||
if not EvaluationRun.create(**run):
|
||||
return False, "Failed to create evaluation run"
|
||||
|
||||
# Execute evaluation asynchronously (in production, use task queue)
|
||||
# For now, we'll execute synchronously
|
||||
cls._execute_evaluation(run_id, dataset_id, dialog)
|
||||
|
||||
return True, run_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error starting evaluation: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
|
||||
"""
|
||||
Execute evaluation for all test cases.
|
||||
|
||||
This method runs the RAG pipeline for each test case and computes metrics.
|
||||
"""
|
||||
try:
|
||||
# Get all test cases
|
||||
test_cases = cls.get_test_cases(dataset_id)
|
||||
|
||||
if not test_cases:
|
||||
EvaluationRun.update(
|
||||
status="FAILED",
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
return
|
||||
|
||||
# Execute each test case
|
||||
results = []
|
||||
for case in test_cases:
|
||||
result = cls._evaluate_single_case(run_id, case, dialog)
|
||||
if result:
|
||||
results.append(result)
|
||||
|
||||
# Compute summary metrics
|
||||
metrics_summary = cls._compute_summary_metrics(results)
|
||||
|
||||
# Update run status
|
||||
EvaluationRun.update(
|
||||
status="COMPLETED",
|
||||
metrics_summary=metrics_summary,
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing evaluation {run_id}: {e}")
|
||||
EvaluationRun.update(
|
||||
status="FAILED",
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
|
||||
@classmethod
|
||||
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
|
||||
dialog: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Evaluate a single test case.
|
||||
|
||||
Args:
|
||||
run_id: Evaluation run ID
|
||||
case: Test case dictionary
|
||||
dialog: Dialog configuration
|
||||
|
||||
Returns:
|
||||
Result dictionary or None if failed
|
||||
"""
|
||||
try:
|
||||
# Prepare messages
|
||||
messages = [{"role": "user", "content": case["question"]}]
|
||||
|
||||
# Execute RAG pipeline
|
||||
start_time = timer()
|
||||
answer = ""
|
||||
retrieved_chunks = []
|
||||
|
||||
|
||||
def _sync_from_async_gen(async_gen):
|
||||
result_queue: queue.Queue = queue.Queue()
|
||||
|
||||
def runner():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def consume():
|
||||
try:
|
||||
async for item in async_gen:
|
||||
result_queue.put(item)
|
||||
except Exception as e:
|
||||
result_queue.put(e)
|
||||
finally:
|
||||
result_queue.put(StopIteration)
|
||||
|
||||
loop.run_until_complete(consume())
|
||||
loop.close()
|
||||
|
||||
threading.Thread(target=runner, daemon=True).start()
|
||||
|
||||
while True:
|
||||
item = result_queue.get()
|
||||
if item is StopIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
from api.db.services.dialog_service import async_chat
|
||||
|
||||
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
|
||||
|
||||
for ans in chat(dialog, messages, stream=False):
|
||||
if isinstance(ans, dict):
|
||||
answer = ans.get("answer", "")
|
||||
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
|
||||
break
|
||||
|
||||
execution_time = timer() - start_time
|
||||
|
||||
# Compute metrics
|
||||
metrics = cls._compute_metrics(
|
||||
question=case["question"],
|
||||
generated_answer=answer,
|
||||
reference_answer=case.get("reference_answer"),
|
||||
retrieved_chunks=retrieved_chunks,
|
||||
relevant_chunk_ids=case.get("relevant_chunk_ids"),
|
||||
dialog=dialog
|
||||
)
|
||||
|
||||
# Save result
|
||||
result_id = get_uuid()
|
||||
result = {
|
||||
"id": result_id,
|
||||
"run_id": run_id,
|
||||
"case_id": case["id"],
|
||||
"generated_answer": answer,
|
||||
"retrieved_chunks": retrieved_chunks,
|
||||
"metrics": metrics,
|
||||
"execution_time": execution_time,
|
||||
"token_usage": None, # TODO: Track token usage
|
||||
"create_time": current_timestamp()
|
||||
}
|
||||
|
||||
EvaluationResult.create(**result)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error evaluating case {case.get('id')}: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _compute_metrics(cls, question: str, generated_answer: str,
|
||||
reference_answer: Optional[str],
|
||||
retrieved_chunks: List[Dict[str, Any]],
|
||||
relevant_chunk_ids: Optional[List[str]],
|
||||
dialog: Any) -> Dict[str, float]:
|
||||
"""
|
||||
Compute evaluation metrics for a single test case.
|
||||
|
||||
Returns:
|
||||
Dictionary of metric names to values
|
||||
"""
|
||||
metrics = {}
|
||||
|
||||
# Retrieval metrics (if ground truth chunks provided)
|
||||
if relevant_chunk_ids:
|
||||
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
|
||||
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
|
||||
|
||||
# Generation metrics
|
||||
if generated_answer:
|
||||
# Basic metrics
|
||||
metrics["answer_length"] = len(generated_answer)
|
||||
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
|
||||
|
||||
# TODO: Implement advanced metrics using LLM-as-judge
|
||||
# - Faithfulness (hallucination detection)
|
||||
# - Answer relevance
|
||||
# - Context relevance
|
||||
# - Semantic similarity (if reference answer provided)
|
||||
|
||||
return metrics
|
||||
|
||||
@classmethod
|
||||
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
|
||||
relevant_ids: List[str]) -> Dict[str, float]:
|
||||
"""
|
||||
Compute retrieval metrics.
|
||||
|
||||
Args:
|
||||
retrieved_ids: List of retrieved chunk IDs
|
||||
relevant_ids: List of relevant chunk IDs (ground truth)
|
||||
|
||||
Returns:
|
||||
Dictionary of retrieval metrics
|
||||
"""
|
||||
if not relevant_ids:
|
||||
return {}
|
||||
|
||||
retrieved_set = set(retrieved_ids)
|
||||
relevant_set = set(relevant_ids)
|
||||
|
||||
# Precision: proportion of retrieved that are relevant
|
||||
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
|
||||
|
||||
# Recall: proportion of relevant that were retrieved
|
||||
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
|
||||
|
||||
# F1 score
|
||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||
|
||||
# Hit rate: whether any relevant chunk was retrieved
|
||||
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
|
||||
|
||||
# MRR (Mean Reciprocal Rank): position of first relevant chunk
|
||||
mrr = 0.0
|
||||
for i, chunk_id in enumerate(retrieved_ids, 1):
|
||||
if chunk_id in relevant_set:
|
||||
mrr = 1.0 / i
|
||||
break
|
||||
|
||||
return {
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1_score": f1,
|
||||
"hit_rate": hit_rate,
|
||||
"mrr": mrr
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Compute summary metrics across all test cases.
|
||||
|
||||
Args:
|
||||
results: List of result dictionaries
|
||||
|
||||
Returns:
|
||||
Summary metrics dictionary
|
||||
"""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
# Aggregate metrics
|
||||
metric_sums = {}
|
||||
metric_counts = {}
|
||||
|
||||
for result in results:
|
||||
metrics = result.get("metrics", {})
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
metric_sums[key] = metric_sums.get(key, 0) + value
|
||||
metric_counts[key] = metric_counts.get(key, 0) + 1
|
||||
|
||||
# Compute averages
|
||||
summary = {
|
||||
"total_cases": len(results),
|
||||
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
|
||||
}
|
||||
|
||||
for key in metric_sums:
|
||||
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
|
||||
|
||||
return summary
|
||||
|
||||
# ==================== Results & Analysis ====================
|
||||
|
||||
@classmethod
|
||||
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
|
||||
"""Get results for an evaluation run"""
|
||||
try:
|
||||
run = EvaluationRun.get_by_id(run_id)
|
||||
if not run:
|
||||
return {}
|
||||
|
||||
results = EvaluationResult.select().where(
|
||||
EvaluationResult.run_id == run_id
|
||||
).order_by(EvaluationResult.create_time)
|
||||
|
||||
return {
|
||||
"run": run.to_dict(),
|
||||
"results": [r.to_dict() for r in results]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting run results {run_id}: {e}")
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Analyze evaluation results and provide configuration recommendations.
|
||||
|
||||
Args:
|
||||
run_id: Evaluation run ID
|
||||
|
||||
Returns:
|
||||
List of recommendation dictionaries
|
||||
"""
|
||||
try:
|
||||
run = EvaluationRun.get_by_id(run_id)
|
||||
if not run or not run.metrics_summary:
|
||||
return []
|
||||
|
||||
metrics = run.metrics_summary
|
||||
recommendations = []
|
||||
|
||||
# Low precision: retrieving irrelevant chunks
|
||||
if metrics.get("avg_precision", 1.0) < 0.7:
|
||||
recommendations.append({
|
||||
"issue": "Low Precision",
|
||||
"severity": "high",
|
||||
"description": "System is retrieving many irrelevant chunks",
|
||||
"suggestions": [
|
||||
"Increase similarity_threshold to filter out less relevant chunks",
|
||||
"Enable reranking to improve chunk ordering",
|
||||
"Reduce top_k to return fewer chunks"
|
||||
]
|
||||
})
|
||||
|
||||
# Low recall: missing relevant chunks
|
||||
if metrics.get("avg_recall", 1.0) < 0.7:
|
||||
recommendations.append({
|
||||
"issue": "Low Recall",
|
||||
"severity": "high",
|
||||
"description": "System is missing relevant chunks",
|
||||
"suggestions": [
|
||||
"Increase top_k to retrieve more chunks",
|
||||
"Lower similarity_threshold to be more inclusive",
|
||||
"Enable hybrid search (keyword + semantic)",
|
||||
"Check chunk size - may be too large or too small"
|
||||
]
|
||||
})
|
||||
|
||||
# Slow response time
|
||||
if metrics.get("avg_execution_time", 0) > 5.0:
|
||||
recommendations.append({
|
||||
"issue": "Slow Response Time",
|
||||
"severity": "medium",
|
||||
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
|
||||
"suggestions": [
|
||||
"Reduce top_k to retrieve fewer chunks",
|
||||
"Optimize embedding model selection",
|
||||
"Consider caching frequently asked questions"
|
||||
]
|
||||
})
|
||||
|
||||
return recommendations
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating recommendations for run {run_id}: {e}")
|
||||
return []
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user