mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Compare commits
225 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8de6b97806 | |||
| e4e0a88053 | |||
| 7719fd6350 | |||
| 15ef6dd72f | |||
| 5b5f19cbc1 | |||
| ea38e12d42 | |||
| 885eb2eab9 | |||
| 6587acef88 | |||
| ad03ede7cd | |||
| 468e4042c2 | |||
| af1344033d | |||
| 4012d65b3c | |||
| e2bc1a3478 | |||
| 6c2c447a72 | |||
| e7022db9a4 | |||
| ca4a0ee1b2 | |||
| 27b0550876 | |||
| 797e03f843 | |||
| b4e06237ef | |||
| 751a13fb64 | |||
| fa7b857aa9 | |||
| 257af75ece | |||
| cbdacf21f6 | |||
| b1f3130519 | |||
| 3c224c817b | |||
| a3c9402218 | |||
| a7d40e9132 | |||
| 648342b62f | |||
| 4870d42949 | |||
| caaf7043cc | |||
| 237a66913b | |||
| 3c50c7d3ac | |||
| b44e65a12e | |||
| e3f40db963 | |||
| b5ad7b7062 | |||
| 6fc7def562 | |||
| c8f608b2dd | |||
| 5c81e01de5 | |||
| 83fac6d0a0 | |||
| a6681d6366 | |||
| 1388c4420d | |||
| 962bd5f5df | |||
| 627c11c429 | |||
| 4ba17361e9 | |||
| c946858328 | |||
| ba6e2af5fd | |||
| 2ffe6f7439 | |||
| e3987e21b9 | |||
| a713f54732 | |||
| 519f03097e | |||
| 299c655e39 | |||
| b8c0fb4572 | |||
| d1e172171f | |||
| 81ae6cf78d | |||
| 1120575021 | |||
| 221947acc4 | |||
| 21d8ffca56 | |||
| 41cff3e09e | |||
| b6c4722687 | |||
| 6ea4248bdc | |||
| 88a28212b3 | |||
| 9d0309aedc | |||
| 9a8ce9d3e2 | |||
| 7499608a8b | |||
| 0ebbb60102 | |||
| 80f6d22d2a | |||
| 088b049b4c | |||
| fa9b7b259c | |||
| 14616cf845 | |||
| d2915f6984 | |||
| ccce8beeeb | |||
| 3d2e0f1a1b | |||
| 918d5a9ff8 | |||
| 7d05d4ced7 | |||
| dbdda0fbab | |||
| cf7fdd274b | |||
| 982ed233a2 | |||
| 1f96c95b42 | |||
| 8604c4f57c | |||
| a674338c21 | |||
| 89d82ff031 | |||
| c71d25f744 | |||
| f57f32cf3a | |||
| b6314164c5 | |||
| 856201c0f2 | |||
| 9d8b96c1d0 | |||
| 7c3c185038 | |||
| a9259917c6 | |||
| 8c28587821 | |||
| 12979a3f21 | |||
| 376eb15c63 | |||
| 89ba7abe30 | |||
| 2fd5ac1031 | |||
| 40e84ca41a | |||
| a28c672695 | |||
| 74e0b58d89 | |||
| 7c20c964b4 | |||
| 5d0981d046 | |||
| a793dd2ea8 | |||
| 915e385244 | |||
| 7a344a32f9 | |||
| 8c1ee3845a | |||
| 8c751d5afc | |||
| f5faf0c94f | |||
| af72e8dc33 | |||
| bcd70affb5 | |||
| 6987e9f23b | |||
| 41665b0865 | |||
| d1744aaaf3 | |||
| d5f8548200 | |||
| 4d8698624c | |||
| 1009819801 | |||
| 8fe782f4ea | |||
| 7140950e93 | |||
| 0181747881 | |||
| 3c41159d26 | |||
| e0e1d04da5 | |||
| f0a14f5fce | |||
| 174a2578e8 | |||
| a0959b9d38 | |||
| 13299197b8 | |||
| 249296e417 | |||
| db0f6840d9 | |||
| 1033a3ae26 | |||
| 1845daf41f | |||
| 4c8f9f0d77 | |||
| cc00c3ec93 | |||
| 653b785958 | |||
| 971c1bcba7 | |||
| 065917bf1c | |||
| 820934fc77 | |||
| d3d2ccc76c | |||
| c8ab9079b3 | |||
| 0d5589bfda | |||
| b846a0f547 | |||
| 69578ebfce | |||
| 06cef71ba6 | |||
| d2b1da0e26 | |||
| 7c6d30f4c8 | |||
| ea0352ee4a | |||
| fa5cf10f56 | |||
| 3fe71ab7dd | |||
| 9f715d6bc2 | |||
| 48de3b26ba | |||
| 273c4bc4d3 | |||
| 420c97199a | |||
| ecf0322165 | |||
| 38234aca53 | |||
| 1c06ec39ca | |||
| cfdccebb17 | |||
| 980a883033 | |||
| 02d429f0ca | |||
| 9c24d5d44a | |||
| 0cc5d7a8a6 | |||
| c43bf1dcf5 | |||
| f76b8279dd | |||
| db5ec89dc5 | |||
| 1c201c4d54 | |||
| ba78d0f0c2 | |||
| add8c63458 | |||
| 83661efdaf | |||
| 971197d595 | |||
| 0884e9a4d9 | |||
| 2de42f00b8 | |||
| e8fe580d7a | |||
| 62505164d5 | |||
| d1dcf3b43c | |||
| f84662d2ee | |||
| 1cb6b7f5dd | |||
| 023f509501 | |||
| 50bc53a1f5 | |||
| 8cd4882596 | |||
| 35e5fade93 | |||
| 4942a23290 | |||
| d1716d865a | |||
| c2b7c305fa | |||
| 341e5904c8 | |||
| ded9bf80c5 | |||
| fea157ba08 | |||
| 0db00f70b2 | |||
| 701761d119 | |||
| 2993fc666b | |||
| 8a6d205df0 | |||
| 912b6b023e | |||
| 89e8818dda | |||
| 1dba6b5bf9 | |||
| 3fcf2ee54c | |||
| d8f413a885 | |||
| 7264fb6978 | |||
| bd4bc57009 | |||
| 0569b50fed | |||
| 6b64641042 | |||
| 9cef3a2625 | |||
| e7e89d3ecb | |||
| 13e212c856 | |||
| 61cf430dbb | |||
| e841b09d63 | |||
| b1a1eedf53 | |||
| 68e3b33ae4 | |||
| cd55f6c1b8 | |||
| 996b5fe14e | |||
| db4fd19c82 | |||
| 12db62b9c7 | |||
| b5f2cf16bc | |||
| e27ff8d3d4 | |||
| 5f59418aba | |||
| 87e69868c0 | |||
| 72c20022f6 | |||
| 3f2472f1b9 | |||
| 1d4d67daf8 | |||
| 7538e218a5 | |||
| 6b52f7df5a | |||
| 63131ec9b2 | |||
| e8f1a245a6 | |||
| 908450509f | |||
| 70a0f081f6 | |||
| 93422fa8cc | |||
| bfc84ba95b | |||
| 871055b0fc | |||
| ba71160b14 | |||
| bd5dda6b10 | |||
| 774563970b | |||
| 83d84e90ed | |||
| 8ef2f79d0a | |||
| 296476ab89 |
50
.github/workflows/tests.yml
vendored
50
.github/workflows/tests.yml
vendored
@ -12,7 +12,7 @@ on:
|
||||
# The only difference between pull_request and pull_request_target is the context in which the workflow runs:
|
||||
# — pull_request_target workflows use the workflow files from the default branch, and secrets are available.
|
||||
# — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable.
|
||||
pull_request_target:
|
||||
pull_request:
|
||||
types: [ synchronize, ready_for_review ]
|
||||
paths-ignore:
|
||||
- 'docs/**'
|
||||
@ -31,7 +31,7 @@ jobs:
|
||||
name: ragflow_tests
|
||||
# https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
|
||||
# https://github.com/orgs/community/discussions/26261
|
||||
if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }}
|
||||
if: ${{ github.event_name != 'pull_request' || (github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'ci')) }}
|
||||
runs-on: [ "self-hosted", "ragflow-test" ]
|
||||
steps:
|
||||
# https://github.com/hmarr/debug-action
|
||||
@ -53,7 +53,7 @@ jobs:
|
||||
- name: Check workflow duplication
|
||||
if: ${{ !cancelled() && !failure() }}
|
||||
run: |
|
||||
if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} != "pull_request" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
HEAD=$(git rev-parse HEAD)
|
||||
# Find a PR that introduced a given commit
|
||||
gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}"
|
||||
@ -78,7 +78,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then
|
||||
elif [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then
|
||||
PR_NUMBER=${{ github.event.pull_request.number }}
|
||||
PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
|
||||
# Calculate the hash of the current workspace content
|
||||
@ -95,6 +95,46 @@ jobs:
|
||||
version: ">=0.11.x"
|
||||
args: "check"
|
||||
|
||||
- name: Check comments of changed Python files
|
||||
if: ${{ false }}
|
||||
run: |
|
||||
if [[ ${{ github.event_name }} == 'pull_request' || ${{ github.event_name }} == 'pull_request_target' ]]; then
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
|
||||
| grep -E '\.(py)$' || true)
|
||||
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
echo "Check comments of changed Python files with check_comment_ascii.py"
|
||||
|
||||
readarray -t files <<< "$CHANGED_FILES"
|
||||
HAS_ERROR=0
|
||||
|
||||
for file in "${files[@]}"; do
|
||||
if [ -f "$file" ]; then
|
||||
if python3 check_comment_ascii.py "$file"; then
|
||||
echo "✅ $file"
|
||||
else
|
||||
echo "❌ $file"
|
||||
HAS_ERROR=1
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $HAS_ERROR -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "No Python files changed"
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Run unit test
|
||||
run: |
|
||||
uv sync --python 3.10 --group test --frozen
|
||||
source .venv/bin/activate
|
||||
which pytest || echo "pytest not in PATH"
|
||||
echo "Start to run unit test"
|
||||
python3 run_tests.py
|
||||
|
||||
- name: Build ragflow:nightly
|
||||
run: |
|
||||
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
|
||||
@ -161,7 +201,7 @@ jobs:
|
||||
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
|
||||
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
|
||||
uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python
|
||||
uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
|
||||
|
||||
- name: Run sdk tests against Elasticsearch
|
||||
run: |
|
||||
|
||||
@ -10,11 +10,10 @@ WORKDIR /ragflow
|
||||
# Copy models downloaded via download_deps.py
|
||||
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
|
||||
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
|
||||
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
|
||||
tar --exclude='.*' -cf - \
|
||||
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
|
||||
/huggingface.co/InfiniFlow/deepdoc \
|
||||
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
||||
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
||||
|
||||
# https://github.com/chrismattmann/tika-python
|
||||
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
|
||||
@ -51,7 +50,9 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y ghostscript && \
|
||||
apt install -y pandoc && \
|
||||
apt install -y texlive
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
|
||||
15
README.md
15
README.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,7 +85,8 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Latest Updates
|
||||
|
||||
- 2025-11-12 Supports data synchronization from Confluence, AWS S3, Discord, Google Drive.
|
||||
- 2025-11-19 Supports Gemini 3 Pro.
|
||||
- 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, Google Drive.
|
||||
- 2025-10-23 Supports MinerU & Docling as document parsing methods.
|
||||
- 2025-10-15 Supports orchestrable ingestion pipeline.
|
||||
- 2025-08-08 Supports OpenAI's latest GPT-5 series models.
|
||||
@ -93,8 +94,6 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
- 2025-05-23 Adds a Python/JavaScript code executor component to Agent.
|
||||
- 2025-05-05 Supports cross-language query.
|
||||
- 2025-03-19 Supports using a multi-modal model to make sense of images within PDF or DOCX files.
|
||||
- 2024-12-18 Upgrades Document Layout Analysis model in DeepDoc.
|
||||
- 2024-08-22 Support text to SQL statements through RAG.
|
||||
|
||||
## 🎉 Stay Tuned
|
||||
|
||||
@ -188,13 +187,15 @@ releases! 🌟
|
||||
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
|
||||
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
|
||||
|
||||
> The command below downloads the `v0.22.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
|
||||
> The command below downloads the `v0.22.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# git checkout v0.22.1
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
|
||||
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
|
||||
11
README_id.md
11
README_id.md
@ -22,7 +22,7 @@
|
||||
<img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
|
||||
@ -85,7 +85,8 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Pembaruan Terbaru
|
||||
|
||||
- 2025-11-12 Mendukung sinkronisasi data dari Confluence, AWS S3, Discord, Google Drive.
|
||||
- 2025-11-19 Mendukung Gemini 3 Pro.
|
||||
- 2025-11-12 Mendukung sinkronisasi data dari Confluence, S3, Notion, Discord, Google Drive.
|
||||
- 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen.
|
||||
- 2025-10-15 Dukungan untuk jalur data yang terorkestrasi.
|
||||
- 2025-08-08 Mendukung model seri GPT-5 terbaru dari OpenAI.
|
||||
@ -186,12 +187,14 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
|
||||
> Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> Perintah di bawah ini mengunduh edisi v0.22.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
|
||||
> Perintah di bawah ini mengunduh edisi v0.22.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
11
README_ja.md
11
README_ja.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -66,7 +66,8 @@
|
||||
|
||||
## 🔥 最新情報
|
||||
|
||||
- 2025-11-12 Confluence、AWS S3、Discord、Google Drive からのデータ同期をサポートします。
|
||||
- 2025-11-19 Gemini 3 Proをサポートしています
|
||||
- 2025-11-12 Confluence、S3、Notion、Discord、Google Drive からのデータ同期をサポートします。
|
||||
- 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。
|
||||
- 2025-10-15 オーケストレーションされたデータパイプラインのサポート。
|
||||
- 2025-08-08 OpenAI の最新 GPT-5 シリーズモデルをサポートします。
|
||||
@ -166,12 +167,14 @@
|
||||
> 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
|
||||
> ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
|
||||
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases)
|
||||
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
11
README_ko.md
11
README_ko.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -67,7 +67,8 @@
|
||||
|
||||
## 🔥 업데이트
|
||||
|
||||
- 2025-11-12 Confluence, AWS S3, Discord, Google Drive에서 데이터 동기화를 지원합니다.
|
||||
- 2025-11-19 Gemini 3 Pro를 지원합니다.
|
||||
- 2025-11-12 Confluence, S3, Notion, Discord, Google Drive에서 데이터 동기화를 지원합니다.
|
||||
- 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다.
|
||||
- 2025-10-15 조정된 데이터 파이프라인 지원.
|
||||
- 2025-08-08 OpenAI의 최신 GPT-5 시리즈 모델을 지원합니다.
|
||||
@ -168,12 +169,14 @@
|
||||
> 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
|
||||
> ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
|
||||
@ -86,7 +86,8 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Últimas Atualizações
|
||||
|
||||
- 12-11-2025 Suporta a sincronização de dados do Confluence, AWS S3, Discord e Google Drive.
|
||||
- 19-11-2025 Suporta Gemini 3 Pro.
|
||||
- 12-11-2025 Suporta a sincronização de dados do Confluence, S3, Notion, Discord e Google Drive.
|
||||
- 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos.
|
||||
- 15-10-2025 Suporte para pipelines de dados orquestrados.
|
||||
- 08-08-2025 Suporta a mais recente série GPT-5 da OpenAI.
|
||||
@ -186,12 +187,14 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
|
||||
> Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
|
||||
|
||||
> O comando abaixo baixa a edição`v0.22.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
|
||||
> O comando abaixo baixa a edição`v0.22.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,7 +85,8 @@
|
||||
|
||||
## 🔥 近期更新
|
||||
|
||||
- 2025-11-12 支援從 Confluence、AWS S3、Discord、Google Drive 進行資料同步。
|
||||
- 2025-11-19 支援 Gemini 3 Pro.
|
||||
- 2025-11-12 支援從 Confluence、S3、Notion、Discord、Google Drive 進行資料同步。
|
||||
- 2025-10-23 支援 MinerU 和 Docling 作為文件解析方法。
|
||||
- 2025-10-15 支援可編排的資料管道。
|
||||
- 2025-08-08 支援 OpenAI 最新的 GPT-5 系列模型。
|
||||
@ -185,12 +186,14 @@
|
||||
> 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。
|
||||
> 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。
|
||||
|
||||
> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.0` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
|
||||
> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases),例:git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases)
|
||||
# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
11
README_zh.md
11
README_zh.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,7 +85,8 @@
|
||||
|
||||
## 🔥 近期更新
|
||||
|
||||
- 2025-11-12 支持从 Confluence、AWS S3、Discord、Google Drive 进行数据同步。
|
||||
- 2025-11-19 支持 Gemini 3 Pro.
|
||||
- 2025-11-12 支持从 Confluence、S3、Notion、Discord、Google Drive 进行数据同步。
|
||||
- 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。
|
||||
- 2025-10-15 支持可编排的数据管道。
|
||||
- 2025-08-08 支持 OpenAI 最新的 GPT-5 系列模型。
|
||||
@ -186,12 +187,14 @@
|
||||
> 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。
|
||||
> 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。
|
||||
|
||||
> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.0` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
|
||||
> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
|
||||
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0
|
||||
# git checkout v0.22.1
|
||||
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases)
|
||||
# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
|
||||
|
||||
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
|
||||
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
|
||||
|
||||
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
|
||||
|
||||
@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
|
||||
1. Ensure the Admin Service is running.
|
||||
2. Install ragflow-cli.
|
||||
```bash
|
||||
pip install ragflow-cli==0.22.0
|
||||
pip install ragflow-cli==0.22.1
|
||||
```
|
||||
3. Launch the CLI client:
|
||||
```bash
|
||||
|
||||
@ -378,7 +378,7 @@ class AdminCLI(Cmd):
|
||||
self.session.headers.update({
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': response.headers['Authorization'],
|
||||
'User-Agent': 'RAGFlow-CLI/0.22.0'
|
||||
'User-Agent': 'RAGFlow-CLI/0.22.1'
|
||||
})
|
||||
print("Authentication successful.")
|
||||
return True
|
||||
@ -393,7 +393,9 @@ class AdminCLI(Cmd):
|
||||
print(f"Can't access {self.host}, port: {self.port}")
|
||||
|
||||
def _format_service_detail_table(self, data):
|
||||
if not any([isinstance(v, list) for v in data.values()]):
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
if not all([isinstance(v, list) for v in data.values()]):
|
||||
# normal table
|
||||
return data
|
||||
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
|
||||
@ -404,7 +406,7 @@ class AdminCLI(Cmd):
|
||||
task_executor_list.append({
|
||||
"task_executor_name": k,
|
||||
**heartbeats[0],
|
||||
})
|
||||
} if heartbeats else {"task_executor_name": k})
|
||||
return task_executor_list
|
||||
|
||||
def _print_table_simple(self, data):
|
||||
@ -415,7 +417,8 @@ class AdminCLI(Cmd):
|
||||
# handle single row data
|
||||
data = [data]
|
||||
|
||||
columns = list(data[0].keys())
|
||||
columns = list(set().union(*(d.keys() for d in data)))
|
||||
columns.sort()
|
||||
col_widths = {}
|
||||
|
||||
def get_string_width(text):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ragflow-cli"
|
||||
version = "0.22.0"
|
||||
version = "0.22.1"
|
||||
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
|
||||
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
|
||||
license = { text = "Apache License, Version 2.0" }
|
||||
@ -8,7 +8,7 @@ readme = "README.md"
|
||||
requires-python = ">=3.10,<3.13"
|
||||
dependencies = [
|
||||
"requests>=2.30.0,<3.0.0",
|
||||
"beartype>=0.18.5,<0.19.0",
|
||||
"beartype>=0.20.0,<1.0.0",
|
||||
"pycryptodomex>=3.10.0",
|
||||
"lark>=1.1.0",
|
||||
]
|
||||
|
||||
298
admin/client/uv.lock
generated
Normal file
298
admin/client/uv.lock
generated
Normal file
@ -0,0 +1,298 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
requires-python = ">=3.10, <3.13"
|
||||
|
||||
[[package]]
|
||||
name = "beartype"
|
||||
version = "0.22.6"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/88/e2/105ceb1704cb80fe4ab3872529ab7b6f365cf7c74f725e6132d0efcf1560/beartype-0.22.6.tar.gz", hash = "sha256:97fbda69c20b48c5780ac2ca60ce3c1bb9af29b3a1a0216898ffabdd523e48f4", size = 1588975, upload-time = "2025-11-20T04:47:14.736Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/98/c9/ceecc71fe2c9495a1d8e08d44f5f31f5bca1350d5b2e27a4b6265424f59e/beartype-0.22.6-py3-none-any.whl", hash = "sha256:0584bc46a2ea2a871509679278cda992eadde676c01356ab0ac77421f3c9a093", size = 1324807, upload-time = "2025-11-20T04:47:11.837Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2025.11.12"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/8c/58f469717fa48465e4a50c014a0400602d3c437d7c0c468e17ada824da3a/certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316", size = 160538, upload-time = "2025-11-12T02:54:51.517Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b", size = 159438, upload-time = "2025-11-12T02:54:49.735Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.4"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/b8/6d51fc1d52cbd52cd4ccedd5b5b2f0f6a11bbf6765c782298b0f3e808541/charset_normalizer-3.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d", size = 209709, upload-time = "2025-10-14T04:40:11.385Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/5c/af/1f9d7f7faafe2ddfb6f72a2e07a548a629c61ad510fe60f9630309908fef/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8", size = 148814, upload-time = "2025-10-14T04:40:13.135Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/79/3d/f2e3ac2bbc056ca0c204298ea4e3d9db9b4afe437812638759db2c976b5f/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad", size = 144467, upload-time = "2025-10-14T04:40:14.728Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ec/85/1bf997003815e60d57de7bd972c57dc6950446a3e4ccac43bc3070721856/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8", size = 162280, upload-time = "2025-10-14T04:40:16.14Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3e/8e/6aa1952f56b192f54921c436b87f2aaf7c7a7c3d0d1a765547d64fd83c13/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d", size = 159454, upload-time = "2025-10-14T04:40:17.567Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/36/3b/60cbd1f8e93aa25d1c669c649b7a655b0b5fb4c571858910ea9332678558/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313", size = 153609, upload-time = "2025-10-14T04:40:19.08Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/64/91/6a13396948b8fd3c4b4fd5bc74d045f5637d78c9675585e8e9fbe5636554/charset_normalizer-3.4.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e", size = 151849, upload-time = "2025-10-14T04:40:20.607Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/7a/59482e28b9981d105691e968c544cc0df3b7d6133152fb3dcdc8f135da7a/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93", size = 151586, upload-time = "2025-10-14T04:40:21.719Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/59/f64ef6a1c4bdd2baf892b04cd78792ed8684fbc48d4c2afe467d96b4df57/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0", size = 145290, upload-time = "2025-10-14T04:40:23.069Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6b/63/3bf9f279ddfa641ffa1962b0db6a57a9c294361cc2f5fcac997049a00e9c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84", size = 163663, upload-time = "2025-10-14T04:40:24.17Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/09/c9e38fc8fa9e0849b172b581fd9803bdf6e694041127933934184e19f8c3/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e", size = 151964, upload-time = "2025-10-14T04:40:25.368Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d2/d1/d28b747e512d0da79d8b6a1ac18b7ab2ecfd81b2944c4c710e166d8dd09c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db", size = 161064, upload-time = "2025-10-14T04:40:26.806Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/bb/9a/31d62b611d901c3b9e5500c36aab0ff5eb442043fb3a1c254200d3d397d9/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6", size = 155015, upload-time = "2025-10-14T04:40:28.284Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/f3/107e008fa2bff0c8b9319584174418e5e5285fef32f79d8ee6a430d0039c/charset_normalizer-3.4.4-cp310-cp310-win32.whl", hash = "sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f", size = 99792, upload-time = "2025-10-14T04:40:29.613Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/eb/66/e396e8a408843337d7315bab30dbf106c38966f1819f123257f5520f8a96/charset_normalizer-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d", size = 107198, upload-time = "2025-10-14T04:40:30.644Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b5/58/01b4f815bf0312704c267f2ccb6e5d42bcc7752340cd487bc9f8c3710597/charset_normalizer-3.4.4-cp310-cp310-win_arm64.whl", hash = "sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69", size = 100262, upload-time = "2025-10-14T04:40:32.108Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988, upload-time = "2025-10-14T04:40:33.79Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324, upload-time = "2025-10-14T04:40:34.961Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742, upload-time = "2025-10-14T04:40:36.105Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863, upload-time = "2025-10-14T04:40:37.188Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837, upload-time = "2025-10-14T04:40:38.435Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550, upload-time = "2025-10-14T04:40:40.053Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162, upload-time = "2025-10-14T04:40:41.163Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019, upload-time = "2025-10-14T04:40:42.276Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310, upload-time = "2025-10-14T04:40:43.439Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022, upload-time = "2025-10-14T04:40:44.547Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383, upload-time = "2025-10-14T04:40:46.018Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098, upload-time = "2025-10-14T04:40:47.081Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991, upload-time = "2025-10-14T04:40:48.246Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456, upload-time = "2025-10-14T04:40:49.376Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 106978, upload-time = "2025-10-14T04:40:50.844Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969, upload-time = "2025-10-14T04:40:52.272Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.11"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lark"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/da/34/28fff3ab31ccff1fd4f6c7c7b0ceb2b6968d8ea4950663eadcb5720591a0/lark-1.3.1.tar.gz", hash = "sha256:b426a7a6d6d53189d318f2b6236ab5d6429eaf09259f1ca33eb716eed10d2905", size = 382732, upload-time = "2025-10-27T18:25:56.653Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/82/3d/14ce75ef66813643812f3093ab17e46d3a206942ce7376d31ec2d36229e7/lark-1.3.1-py3-none-any.whl", hash = "sha256:c629b661023a014c37da873b4ff58a817398d12635d3bbb2c5a03be7fe5d1e12", size = 113151, upload-time = "2025-10-27T18:25:54.882Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "25.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycryptodomex"
|
||||
version = "3.23.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/85/e24bf90972a30b0fcd16c73009add1d7d7cd9140c2498a68252028899e41/pycryptodomex-3.23.0.tar.gz", hash = "sha256:71909758f010c82bc99b0abf4ea12012c98962fbf0583c2164f8b84533c2e4da", size = 4922157, upload-time = "2025-05-17T17:23:41.434Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/dd/9c/1a8f35daa39784ed8adf93a694e7e5dc15c23c741bbda06e1d45f8979e9e/pycryptodomex-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:06698f957fe1ab229a99ba2defeeae1c09af185baa909a31a5d1f9d42b1aaed6", size = 2499240, upload-time = "2025-05-17T17:22:46.953Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/62/f5221a191a97157d240cf6643747558759126c76ee92f29a3f4aee3197a5/pycryptodomex-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b2c2537863eccef2d41061e82a881dcabb04944c5c06c5aa7110b577cc487545", size = 1644042, upload-time = "2025-05-17T17:22:49.098Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/8c/fd/5a054543c8988d4ed7b612721d7e78a4b9bf36bc3c5ad45ef45c22d0060e/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43c446e2ba8df8889e0e16f02211c25b4934898384c1ec1ec04d7889c0333587", size = 2186227, upload-time = "2025-05-17T17:22:51.139Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c8/a9/8862616a85cf450d2822dbd4fff1fcaba90877907a6ff5bc2672cafe42f8/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f489c4765093fb60e2edafdf223397bc716491b2b69fe74367b70d6999257a5c", size = 2272578, upload-time = "2025-05-17T17:22:53.676Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/46/9f/bda9c49a7c1842820de674ab36c79f4fbeeee03f8ff0e4f3546c3889076b/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bdc69d0d3d989a1029df0eed67cc5e8e5d968f3724f4519bd03e0ec68df7543c", size = 2312166, upload-time = "2025-05-17T17:22:56.585Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/cc/870b9bf8ca92866ca0186534801cf8d20554ad2a76ca959538041b7a7cf4/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:6bbcb1dd0f646484939e142462d9e532482bc74475cecf9c4903d4e1cd21f003", size = 2185467, upload-time = "2025-05-17T17:22:59.237Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/96/e3/ce9348236d8e669fea5dd82a90e86be48b9c341210f44e25443162aba187/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:8a4fcd42ccb04c31268d1efeecfccfd1249612b4de6374205376b8f280321744", size = 2346104, upload-time = "2025-05-17T17:23:02.112Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a5/e9/e869bcee87beb89040263c416a8a50204f7f7a83ac11897646c9e71e0daf/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:55ccbe27f049743a4caf4f4221b166560d3438d0b1e5ab929e07ae1702a4d6fd", size = 2271038, upload-time = "2025-05-17T17:23:04.872Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/8d/67/09ee8500dd22614af5fbaa51a4aee6e342b5fa8aecf0a6cb9cbf52fa6d45/pycryptodomex-3.23.0-cp37-abi3-win32.whl", hash = "sha256:189afbc87f0b9f158386bf051f720e20fa6145975f1e76369303d0f31d1a8d7c", size = 1771969, upload-time = "2025-05-17T17:23:07.115Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/69/96/11f36f71a865dd6df03716d33bd07a67e9d20f6b8d39820470b766af323c/pycryptodomex-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:52e5ca58c3a0b0bd5e100a9fbc8015059b05cffc6c66ce9d98b4b45e023443b9", size = 1803124, upload-time = "2025-05-17T17:23:09.267Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/93/45c1cdcbeb182ccd2e144c693eaa097763b08b38cded279f0053ed53c553/pycryptodomex-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:02d87b80778c171445d67e23d1caef279bf4b25c3597050ccd2e13970b57fd51", size = 1707161, upload-time = "2025-05-17T17:23:11.414Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/b8/3e76d948c3c4ac71335bbe75dac53e154b40b0f8f1f022dfa295257a0c96/pycryptodomex-3.23.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ebfff755c360d674306e5891c564a274a47953562b42fb74a5c25b8fc1fb1cb5", size = 1627695, upload-time = "2025-05-17T17:23:17.38Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6a/cf/80f4297a4820dfdfd1c88cf6c4666a200f204b3488103d027b5edd9176ec/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eca54f4bb349d45afc17e3011ed4264ef1cc9e266699874cdd1349c504e64798", size = 1675772, upload-time = "2025-05-17T17:23:19.202Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/42/1e969ee0ad19fe3134b0e1b856c39bd0b70d47a4d0e81c2a8b05727394c9/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2596e643d4365e14d0879dc5aafe6355616c61c2176009270f3048f6d9a61f", size = 1668083, upload-time = "2025-05-17T17:23:21.867Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/c3/1de4f7631fea8a992a44ba632aa40e0008764c0fb9bf2854b0acf78c2cf2/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdfac7cda115bca3a5abb2f9e43bc2fb66c2b65ab074913643803ca7083a79ea", size = 1706056, upload-time = "2025-05-17T17:23:24.031Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f2/5f/af7da8e6f1e42b52f44a24d08b8e4c726207434e2593732d39e7af5e7256/pycryptodomex-3.23.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:14c37aaece158d0ace436f76a7bb19093db3b4deade9797abfc39ec6cd6cc2fe", size = 1806478, upload-time = "2025-05-17T17:23:26.066Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "9.0.1"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/56/f013048ac4bc4c1d9be45afd4ab209ea62822fb1598f40687e6bf45dcea4/pytest-9.0.1.tar.gz", hash = "sha256:3e9c069ea73583e255c3b21cf46b8d3c56f6e3a1a8f6da94ccb0fcf57b9d73c8", size = 1564125, upload-time = "2025-11-12T13:05:09.333Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0b/8b/6300fb80f858cda1c51ffa17075df5d846757081d11ab4aa35cef9e6258b/pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad", size = 373668, upload-time = "2025-11-12T13:05:07.379Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ragflow-cli"
|
||||
version = "0.22.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "beartype" },
|
||||
{ name = "lark" },
|
||||
{ name = "pycryptodomex" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "requests" },
|
||||
{ name = "requests-toolbelt" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "beartype", specifier = ">=0.20.0,<1.0.0" },
|
||||
{ name = "lark", specifier = ">=1.1.0" },
|
||||
{ name = "pycryptodomex", specifier = ">=3.10.0" },
|
||||
{ name = "requests", specifier = ">=2.30.0,<3.0.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
test = [
|
||||
{ name = "pytest", specifier = ">=8.3.5" },
|
||||
{ name = "requests", specifier = ">=2.32.3" },
|
||||
{ name = "requests-toolbelt", specifier = ">=1.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.5"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "charset-normalizer" },
|
||||
{ name = "idna" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests-toolbelt"
|
||||
version = "1.0.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/61/d7545dafb7ac2230c70d38d31cbfe4cc64f7144dc41f6e4e4b78ecd9f5bb/requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6", size = 206888, upload-time = "2023-05-01T04:11:33.229Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.15.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.5.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
|
||||
]
|
||||
@ -20,8 +20,11 @@ import logging
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
from werkzeug.serving import run_simple
|
||||
import faulthandler
|
||||
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager
|
||||
from werkzeug.serving import run_simple
|
||||
from routes import admin_bp
|
||||
from common.log_utils import init_root_logger
|
||||
from common.constants import SERVICE_CONF
|
||||
@ -30,12 +33,12 @@ from common import settings
|
||||
from config import load_configurations, SERVICE_CONFIGS
|
||||
from auth import init_default_admin, setup_auth
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from common.versions import get_ragflow_version
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
if __name__ == '__main__':
|
||||
faulthandler.enable()
|
||||
init_root_logger("admin_service")
|
||||
logging.info(r"""
|
||||
____ ___ ______________ ___ __ _
|
||||
|
||||
@ -19,7 +19,8 @@ import logging
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
from flask import request, jsonify
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_login import current_user, login_user
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
@ -30,7 +31,7 @@ from common.constants import ActiveEnum, StatusEnum
|
||||
from api.utils.crypt import decrypt
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, datetime_format, get_format_time
|
||||
from common.connection_utils import construct_response
|
||||
from common.connection_utils import sync_construct_response
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -129,7 +130,7 @@ def login_admin(email: str, password: str):
|
||||
user.last_login_time = get_format_time()
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=resp, auth=user.get_id(), message=msg)
|
||||
return sync_construct_response(data=resp, auth=user.get_id(), message=msg)
|
||||
|
||||
|
||||
def check_admin(username: str, password: str):
|
||||
@ -169,7 +170,7 @@ def login_verify(f):
|
||||
username = auth.parameters['username']
|
||||
password = auth.parameters['password']
|
||||
try:
|
||||
if check_admin(username, password) is False:
|
||||
if not check_admin(username, password):
|
||||
return jsonify({
|
||||
"code": 500,
|
||||
"message": "Access denied",
|
||||
|
||||
@ -25,8 +25,21 @@ from common.config_utils import read_config
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
service_type: str
|
||||
detail_func_name: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
|
||||
'service_type': self.service_type}
|
||||
|
||||
|
||||
class ServiceConfigs:
|
||||
configs = dict
|
||||
configs = list[BaseConfig]
|
||||
|
||||
def __init__(self):
|
||||
self.configs = []
|
||||
@ -45,19 +58,6 @@ class ServiceType(Enum):
|
||||
FILE_STORE = "file_store"
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
service_type: str
|
||||
detail_func_name: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
|
||||
'service_type': self.service_type}
|
||||
|
||||
|
||||
class MetaConfig(BaseConfig):
|
||||
meta_type: str
|
||||
|
||||
@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
|
||||
ragflow_count = 0
|
||||
id_count = 0
|
||||
for k, v in raw_configs.items():
|
||||
match (k):
|
||||
match k:
|
||||
case "ragflow":
|
||||
name: str = f'ragflow_{ragflow_count}'
|
||||
host: str = v['host']
|
||||
|
||||
@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import secrets
|
||||
|
||||
from flask import Blueprint, request
|
||||
from flask_login import current_user, logout_user, login_required
|
||||
from flask_login import current_user, login_required, logout_user
|
||||
|
||||
from auth import login_verify, login_admin, check_admin_auth
|
||||
from responses import success_response, error_response
|
||||
|
||||
@ -13,8 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from werkzeug.security import check_password_hash
|
||||
from common.constants import ActiveEnum
|
||||
@ -190,7 +189,8 @@ class ServiceMgr:
|
||||
config_dict['status'] = service_detail['status']
|
||||
else:
|
||||
config_dict['status'] = 'timeout'
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.warning(f"Can't get service details, error: {e}")
|
||||
config_dict['status'] = 'timeout'
|
||||
if not config_dict['host']:
|
||||
config_dict['host'] = '-'
|
||||
@ -205,17 +205,13 @@ class ServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_service_details(service_id: int):
|
||||
service_id = int(service_id)
|
||||
service_idx = int(service_id)
|
||||
configs = SERVICE_CONFIGS.configs
|
||||
service_config_mapping = {
|
||||
c.id: {
|
||||
'name': c.name,
|
||||
'detail_func_name': c.detail_func_name
|
||||
} for c in configs
|
||||
}
|
||||
service_info = service_config_mapping.get(service_id, {})
|
||||
if not service_info:
|
||||
raise AdminException(f"invalid service_id: {service_id}")
|
||||
if service_idx < 0 or service_idx >= len(configs):
|
||||
raise AdminException(f"invalid service_index: {service_idx}")
|
||||
|
||||
service_config = configs[service_idx]
|
||||
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
|
||||
|
||||
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
|
||||
res = detail_func()
|
||||
|
||||
@ -14,5 +14,5 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
# from beartype.claw import beartype_this_package
|
||||
# beartype_this_package()
|
||||
|
||||
280
agent/canvas.py
280
agent/canvas.py
@ -13,7 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import binascii
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -26,7 +29,9 @@ from typing import Any, Union, Tuple
|
||||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.constants import LLMType
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
from rag.prompts.generator import chunks_format
|
||||
@ -80,14 +85,12 @@ class Graph:
|
||||
self.dsl = json.loads(dsl)
|
||||
self._tenant_id = tenant_id
|
||||
self.task_id = task_id if task_id else get_uuid()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=5)
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
self.components = self.dsl["components"]
|
||||
cpn_nms = set([])
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||
@ -207,17 +210,60 @@ class Graph:
|
||||
for key in path.split('.'):
|
||||
if cur is None:
|
||||
return None
|
||||
|
||||
if isinstance(cur, str):
|
||||
try:
|
||||
cur = json.loads(cur)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(key)
|
||||
else:
|
||||
cur = getattr(cur, key, None)
|
||||
continue
|
||||
|
||||
if isinstance(cur, (list, tuple)):
|
||||
try:
|
||||
idx = int(key)
|
||||
cur = cur[idx]
|
||||
except Exception:
|
||||
return None
|
||||
continue
|
||||
|
||||
cur = getattr(cur, key, None)
|
||||
return cur
|
||||
|
||||
def set_variable_value(self, exp: str,value):
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
self.globals[exp] = value
|
||||
return
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
parts = var_nm.split(".", 1)
|
||||
root_key = parts[0]
|
||||
rest = parts[1] if len(parts) > 1 else ""
|
||||
if not rest:
|
||||
cpn["obj"].set_output(root_key, value)
|
||||
return
|
||||
root_val = cpn["obj"].output(root_key)
|
||||
if not root_val:
|
||||
root_val = {}
|
||||
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value))
|
||||
|
||||
def set_variable_param_value(self, obj: Any, path: str, value) -> Any:
|
||||
cur = obj
|
||||
keys = path.split('.')
|
||||
if not path:
|
||||
return value
|
||||
for key in keys:
|
||||
if key not in cur or not isinstance(cur[key], dict):
|
||||
cur[key] = {}
|
||||
cur = cur[key]
|
||||
cur[keys[-1]] = value
|
||||
return obj
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
return has_canceled(self.task_id)
|
||||
|
||||
@ -239,6 +285,7 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
self.variables = {}
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
|
||||
def load(self):
|
||||
@ -253,6 +300,10 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
if "variables" in self.dsl:
|
||||
self.variables = self.dsl["variables"]
|
||||
else:
|
||||
self.variables = {}
|
||||
|
||||
self.retrieval = self.dsl["retrieval"]
|
||||
self.memory = self.dsl.get("memory", [])
|
||||
@ -269,6 +320,7 @@ class Canvas(Graph):
|
||||
self.history = []
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
print(self.variables)
|
||||
for k in self.globals.keys():
|
||||
if k.startswith("sys."):
|
||||
if isinstance(self.globals[k], str):
|
||||
@ -283,9 +335,31 @@ class Canvas(Graph):
|
||||
self.globals[k] = {}
|
||||
else:
|
||||
self.globals[k] = None
|
||||
if k.startswith("env."):
|
||||
key = k[4:]
|
||||
if key in self.variables:
|
||||
variable = self.variables[key]
|
||||
if variable["value"]:
|
||||
self.globals[k] = variable["value"]
|
||||
else:
|
||||
if variable["type"] == "string":
|
||||
self.globals[k] = ""
|
||||
elif variable["type"] == "number":
|
||||
self.globals[k] = 0
|
||||
elif variable["type"] == "boolean":
|
||||
self.globals[k] = False
|
||||
elif variable["type"] == "object":
|
||||
self.globals[k] = {}
|
||||
elif variable["type"].startswith("array"):
|
||||
self.globals[k] = []
|
||||
else:
|
||||
self.globals[k] = ""
|
||||
else:
|
||||
self.globals[k] = ""
|
||||
|
||||
def run(self, **kwargs):
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self.message_id = get_uuid()
|
||||
created_at = int(time.time())
|
||||
self.add_user_input(kwargs.get("query"))
|
||||
@ -298,12 +372,10 @@ class Canvas(Graph):
|
||||
for kk, vv in kwargs["webhook_payload"].items():
|
||||
self.components[k]["obj"].set_output(kk, vv)
|
||||
|
||||
self.components[k]["obj"].reset(True)
|
||||
|
||||
for k in kwargs.keys():
|
||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||
if k == "files":
|
||||
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
|
||||
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
|
||||
else:
|
||||
self.globals[f"sys.{k}"] = kwargs[k]
|
||||
if not self.globals["sys.conversation_turns"] :
|
||||
@ -333,31 +405,50 @@ class Canvas(Graph):
|
||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||
|
||||
def _run_batch(f, t):
|
||||
async def _run_batch(f, t):
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
||||
i += 1
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
|
||||
def _run_async_in_thread(coro_func, **call_kwargs):
|
||||
return asyncio.run(coro_func(**call_kwargs))
|
||||
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
task_fn = None
|
||||
call_kwargs = None
|
||||
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
call_kwargs = {"inputs": kwargs.get("inputs", {})}
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||
i += 1
|
||||
for t in thr:
|
||||
t.result()
|
||||
call_kwargs = cpn.get_input()
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
|
||||
if task_fn is None:
|
||||
continue
|
||||
|
||||
invoke_async = getattr(cpn, "invoke_async", None)
|
||||
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
||||
else:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _node_finished(cpn_obj):
|
||||
return decorate("node_finished",{
|
||||
@ -374,6 +465,7 @@ class Canvas(Graph):
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
partials = []
|
||||
tts_mdl = None
|
||||
while idx < len(self.path):
|
||||
to = len(self.path)
|
||||
for i in range(idx, to):
|
||||
@ -384,31 +476,70 @@ class Canvas(Graph):
|
||||
"component_type": self.get_component_type(self.path[i]),
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
_run_batch(idx, to)
|
||||
await _run_batch(idx, to)
|
||||
to = len(self.path)
|
||||
# post processing of components invocation
|
||||
for i in range(idx, to):
|
||||
cpn = self.get_component(self.path[i])
|
||||
cpn_obj = self.get_component_obj(self.path[i])
|
||||
if cpn_obj.component_name.lower() == "message":
|
||||
if cpn_obj.get_param("auto_play"):
|
||||
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
_m = ""
|
||||
for m in cpn_obj.output("content")():
|
||||
buff_m = ""
|
||||
stream = cpn_obj.output("content")()
|
||||
async def _process_stream(m):
|
||||
nonlocal buff_m, _m, tts_mdl
|
||||
if not m:
|
||||
continue
|
||||
return
|
||||
if m == "<think>":
|
||||
yield decorate("message", {"content": "", "start_to_think": True})
|
||||
return decorate("message", {"content": "", "start_to_think": True})
|
||||
|
||||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
return decorate("message", {"content": "", "end_to_think": True})
|
||||
|
||||
buff_m += m
|
||||
_m += m
|
||||
|
||||
if len(buff_m) > 16:
|
||||
ev = decorate(
|
||||
"message",
|
||||
{
|
||||
"content": m,
|
||||
"audio_binary": self.tts(tts_mdl, buff_m)
|
||||
}
|
||||
)
|
||||
buff_m = ""
|
||||
return ev
|
||||
|
||||
return decorate("message", {"content": m})
|
||||
|
||||
if inspect.isasyncgen(stream):
|
||||
async for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
else:
|
||||
for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
if buff_m:
|
||||
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||
buff_m = ""
|
||||
cpn_obj.set_output("content", _m)
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||
else:
|
||||
yield decorate("message", {"content": cpn_obj.output("content")})
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
||||
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
|
||||
|
||||
message_end = {}
|
||||
if isinstance(cpn_obj.output("attachment"), dict):
|
||||
message_end["attachment"] = cpn_obj.output("attachment")
|
||||
if cite:
|
||||
message_end["reference"] = self.get_reference()
|
||||
yield decorate("message_end", message_end)
|
||||
|
||||
while partials:
|
||||
_cpn_obj = self.get_component_obj(partials[0])
|
||||
@ -429,7 +560,7 @@ class Canvas(Graph):
|
||||
else:
|
||||
self.error = cpn_obj.error()
|
||||
|
||||
if cpn_obj.component_name.lower() != "iteration":
|
||||
if cpn_obj.component_name.lower() not in ("iteration","loop"):
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
if self.error:
|
||||
cpn_obj.set_output("content", None)
|
||||
@ -454,14 +585,16 @@ class Canvas(Graph):
|
||||
for cpn_id in cpn_ids:
|
||||
_append_path(cpn_id)
|
||||
|
||||
if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
|
||||
if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end():
|
||||
iter = cpn_obj.get_parent()
|
||||
yield _node_finished(iter)
|
||||
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
|
||||
elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
|
||||
_extend_path(cpn_obj.output("_next"))
|
||||
elif cpn_obj.component_name.lower() == "iteration":
|
||||
elif cpn_obj.component_name.lower() in ("iteration", "loop"):
|
||||
_append_path(cpn_obj.get_start())
|
||||
elif cpn_obj.component_name.lower() == "exitloop" and cpn_obj.get_parent().component_name.lower() == "loop":
|
||||
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
|
||||
elif not cpn["downstream"] and cpn_obj.get_parent():
|
||||
_append_path(cpn_obj.get_parent().get_start())
|
||||
else:
|
||||
@ -517,6 +650,50 @@ class Canvas(Graph):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def tts(self,tts_mdl, text):
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||
|
||||
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||
|
||||
emoji_pattern = re.compile(
|
||||
"[\U0001F600-\U0001F64F"
|
||||
"\U0001F300-\U0001F5FF"
|
||||
"\U0001F680-\U0001F6FF"
|
||||
"\U0001F1E0-\U0001F1FF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U0001F900-\U0001F9FF"
|
||||
"\U0001FA70-\U0001FAFF"
|
||||
"\U0001FAD0-\U0001FAFF]+",
|
||||
flags=re.UNICODE
|
||||
)
|
||||
text = emoji_pattern.sub("", text)
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
MAX_LEN = 500
|
||||
if len(text) > MAX_LEN:
|
||||
text = text[:MAX_LEN]
|
||||
|
||||
return text
|
||||
if not tts_mdl or not text:
|
||||
return None
|
||||
text = clean_tts_text(text)
|
||||
if not text:
|
||||
return None
|
||||
bin = b""
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
except Exception as e:
|
||||
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
@ -546,20 +723,30 @@ class Canvas(Graph):
|
||||
def get_component_input_elements(self, cpnnm):
|
||||
return self.components[cpnnm]["obj"].get_input_elements()
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return [th.result() for th in threads]
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
"""
|
||||
Synchronous wrapper for get_files_async, used by sync component invoke paths.
|
||||
"""
|
||||
loop = getattr(self, "_loop", None)
|
||||
if loop and loop.is_running():
|
||||
return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
|
||||
|
||||
return asyncio.run(self.get_files_async(files))
|
||||
|
||||
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
|
||||
agent_ids = agent_id.split("-->")
|
||||
@ -613,4 +800,3 @@ class Canvas(Graph):
|
||||
|
||||
def get_component_thoughts(self, cpn_id) -> str:
|
||||
return self.components.get(cpn_id)["obj"].thoughts()
|
||||
|
||||
|
||||
@ -13,10 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
@ -28,9 +29,9 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from common.connection_utils import timeout
|
||||
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
|
||||
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
|
||||
@ -137,8 +138,34 @@ class Agent(LLM, ToolBase):
|
||||
res.update(cpn.get_input_form())
|
||||
return res
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
def _get_output_schema(self):
|
||||
try:
|
||||
cand = self._param.outputs.get("structured")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if isinstance(cand, dict):
|
||||
if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
|
||||
return cand
|
||||
for k in ("schema", "structured"):
|
||||
if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
|
||||
return cand[k]
|
||||
|
||||
return None
|
||||
|
||||
async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
|
||||
fmt_msgs = [
|
||||
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
|
||||
{"role": "user", "content": text},
|
||||
]
|
||||
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
|
||||
return await self._generate_async(fmt_msgs)
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
|
||||
@ -157,25 +184,25 @@ class Agent(LLM, ToolBase):
|
||||
if not self.tools:
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
return LLM._invoke(self, **kwargs)
|
||||
return await LLM._invoke_async(self, **kwargs)
|
||||
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
output_schema = self._get_output_schema()
|
||||
schema_prompt = ""
|
||||
if output_schema:
|
||||
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
|
||||
schema_prompt = structured_output_prompt(schema)
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
output_structure=None
|
||||
try:
|
||||
output_structure=self._param.outputs['structured']
|
||||
except Exception:
|
||||
pass
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
||||
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
use_tools = []
|
||||
ans = ""
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
ans += delta_ans
|
||||
@ -188,16 +215,38 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
if output_schema:
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
try:
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
obj = json_repair.loads(clean_formated_answer(ans))
|
||||
self.set_output("structured", obj)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return obj
|
||||
except Exception:
|
||||
error = "The answer cannot be parsed as JSON"
|
||||
ans = await self._force_format_to_schema_async(ans, schema_prompt)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
continue
|
||||
|
||||
self.set_output("_ERROR", error)
|
||||
return
|
||||
|
||||
self.set_output("content", ans)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
||||
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
|
||||
@ -215,39 +264,23 @@ class Agent(LLM, ToolBase):
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
def _gen_citations(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||
{"role": "user", "content": text}
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
|
||||
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||
token_count = 0
|
||||
tool_metas = self.tool_meta
|
||||
hist = deepcopy(history)
|
||||
last_calling = ""
|
||||
if len(hist) > 3:
|
||||
st = timer()
|
||||
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
|
||||
user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
|
||||
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
|
||||
else:
|
||||
user_request = history[-1]["content"]
|
||||
|
||||
def use_tool(name, args):
|
||||
nonlocal hist, use_tools, token_count,last_calling,user_request
|
||||
async def use_tool_async(name, args):
|
||||
nonlocal hist, use_tools, last_calling
|
||||
logging.info(f"{last_calling=} == {name=}")
|
||||
# Summarize of function calling
|
||||
#if all([
|
||||
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
|
||||
# last_calling,
|
||||
# last_calling != name
|
||||
#]):
|
||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
|
||||
last_calling = name
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
tool_response = await self.toolcall_session.tool_call_async(name, args)
|
||||
use_tools.append({
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
@ -258,12 +291,16 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
return name, tool_response
|
||||
|
||||
def complete():
|
||||
async def complete():
|
||||
nonlocal hist
|
||||
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||
if schema_prompt:
|
||||
need2cite = False
|
||||
cited = False
|
||||
if hist[0]["role"] == "system" and need2cite:
|
||||
if len(hist) < 7:
|
||||
if hist and hist[0]["role"] == "system":
|
||||
if schema_prompt:
|
||||
hist[0]["content"] += "\n" + schema_prompt
|
||||
if need2cite and len(hist) < 7:
|
||||
hist[0]["content"] += citation_prompt()
|
||||
cited = True
|
||||
yield "", token_count
|
||||
@ -272,7 +309,7 @@ class Agent(LLM, ToolBase):
|
||||
if len(hist) > 12:
|
||||
_hist = [hist[0], hist[1], *hist[-10:]]
|
||||
entire_txt = ""
|
||||
for delta_ans in self._generate_streamly(_hist):
|
||||
async for delta_ans in self._generate_streamly_async(_hist):
|
||||
if not need2cite or cited:
|
||||
yield delta_ans, 0
|
||||
entire_txt += delta_ans
|
||||
@ -281,7 +318,7 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
st = timer()
|
||||
txt = ""
|
||||
for delta_ans in self._gen_citations(entire_txt):
|
||||
async for delta_ans in self._gen_citations_async(entire_txt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
yield delta_ans, 0
|
||||
@ -296,14 +333,14 @@ class Agent(LLM, ToolBase):
|
||||
hist.append({"role": "user", "content": content})
|
||||
|
||||
st = timer()
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
token_count += tk
|
||||
token_count += tk or 0
|
||||
hist.append({"role": "assistant", "content": response})
|
||||
try:
|
||||
functions = json_repair.loads(re.sub(r"```.*", "", response))
|
||||
@ -312,23 +349,24 @@ class Agent(LLM, ToolBase):
|
||||
for f in functions:
|
||||
if not isinstance(f, dict):
|
||||
raise TypeError(f"An object type should be returned, but `{f}`")
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
for func in functions:
|
||||
name = func["name"]
|
||||
args = func["arguments"]
|
||||
if name == COMPLETE_TASK:
|
||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
return
|
||||
|
||||
thr.append(executor.submit(use_tool, name, args))
|
||||
tool_tasks = []
|
||||
for func in functions:
|
||||
name = func["name"]
|
||||
args = func["arguments"]
|
||||
if name == COMPLETE_TASK:
|
||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
return
|
||||
|
||||
st = timer()
|
||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
|
||||
|
||||
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
|
||||
st = timer()
|
||||
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
||||
@ -352,27 +390,30 @@ Respond immediately with your final comprehensive answer.
|
||||
return
|
||||
append_user_content(hist, final_instruction)
|
||||
|
||||
for txt, tkcnt in complete():
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
|
||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
|
||||
# self.callback("get_useful_memory", {"topn": 3}, "...")
|
||||
mems = self._canvas.get_memory()
|
||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
|
||||
try:
|
||||
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
|
||||
mems = [mems[r] for r in rank]
|
||||
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
async def _gen_citations_async(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||
async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||
{"role": "user", "content": text}
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
return "Error occurred."
|
||||
|
||||
def reset(self, temp=False):
|
||||
def reset(self, only_output=False):
|
||||
"""
|
||||
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
|
||||
"""
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
|
||||
for k, cpn in self.tools.items():
|
||||
if hasattr(cpn, "reset") and callable(cpn.reset):
|
||||
cpn.reset()
|
||||
|
||||
if only_output:
|
||||
return
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
@ -445,6 +446,34 @@ class ComponentBase(ABC):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
async def invoke_async(self, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Async wrapper for component invocation.
|
||||
Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
|
||||
Handles timing and error recording consistently with `invoke`.
|
||||
"""
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
if self.check_if_canceled("Component processing"):
|
||||
return
|
||||
|
||||
fn_async = getattr(self, "_invoke_async", None)
|
||||
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||
await fn_async(**kwargs)
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
await self._invoke(**kwargs)
|
||||
else:
|
||||
await asyncio.to_thread(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self._param.debug_inputs = {}
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
@ -463,12 +492,15 @@ class ComponentBase(ABC):
|
||||
return self._param.outputs.get("_ERROR", {}).get("value")
|
||||
|
||||
def reset(self, only_output=False):
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
outputs: dict = self._param.outputs # for better performance
|
||||
for k in outputs.keys():
|
||||
outputs[k]["value"] = None
|
||||
if only_output:
|
||||
return
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
|
||||
inputs: dict = self._param.inputs # for better performance
|
||||
for k in inputs.keys():
|
||||
inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from agent.component.fillup import UserFillUpParam, UserFillUp
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
class BeginParam(UserFillUpParam):
|
||||
@ -48,7 +49,7 @@ class Begin(UserFillUp):
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
v = self._canvas.get_files([v["value"]])
|
||||
v = FileService.get_files([v["value"]])
|
||||
else:
|
||||
v = v.get("value")
|
||||
self.set_output(k, v)
|
||||
|
||||
@ -1,3 +1,18 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import ast
|
||||
import os
|
||||
|
||||
32
agent/component/exit_loop.py
Normal file
32
agent/component/exit_loop.py
Normal file
@ -0,0 +1,32 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class ExitLoopParam(ComponentParamBase, ABC):
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
|
||||
class ExitLoop(ComponentBase, ABC):
|
||||
component_name = "ExitLoop"
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
pass
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
@ -18,6 +18,7 @@ import re
|
||||
from functools import partial
|
||||
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
class UserFillUpParam(ComponentParamBase):
|
||||
@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
|
||||
for k, v in kwargs.get("inputs", {}).items():
|
||||
if self.check_if_canceled("UserFillUp processing"):
|
||||
return
|
||||
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
v = FileService.get_files([v["value"]])
|
||||
else:
|
||||
v = v.get("value")
|
||||
self.set_output(k, v)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
|
||||
@ -32,6 +32,7 @@ class IterationParam(ComponentParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.items_ref = ""
|
||||
self.variable={}
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
|
||||
168
agent/component/list_operations.py
Normal file
168
agent/component/list_operations.py
Normal file
@ -0,0 +1,168 @@
|
||||
from abc import ABC
|
||||
import os
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
class ListOperationsParam(ComponentParamBase):
|
||||
"""
|
||||
Define the List Operations component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.query = ""
|
||||
self.operations = "topN"
|
||||
self.n=0
|
||||
self.sort_method = "asc"
|
||||
self.filter = {
|
||||
"operator": "=",
|
||||
"value": ""
|
||||
}
|
||||
self.outputs = {
|
||||
"result": {
|
||||
"value": [],
|
||||
"type": "Array of ?"
|
||||
},
|
||||
"first": {
|
||||
"value": "",
|
||||
"type": "?"
|
||||
},
|
||||
"last": {
|
||||
"value": "",
|
||||
"type": "?"
|
||||
}
|
||||
}
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.query, "query")
|
||||
self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {}
|
||||
|
||||
|
||||
class ListOperations(ComponentBase,ABC):
|
||||
component_name = "ListOperations"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
self.input_objects=[]
|
||||
inputs = getattr(self._param, "query", None)
|
||||
self.inputs = self._canvas.get_variable_value(inputs)
|
||||
if not isinstance(self.inputs, list):
|
||||
raise TypeError("The input of List Operations should be an array.")
|
||||
self.set_input_value(inputs, self.inputs)
|
||||
if self._param.operations == "topN":
|
||||
self._topN()
|
||||
elif self._param.operations == "head":
|
||||
self._head()
|
||||
elif self._param.operations == "tail":
|
||||
self._tail()
|
||||
elif self._param.operations == "filter":
|
||||
self._filter()
|
||||
elif self._param.operations == "sort":
|
||||
self._sort()
|
||||
elif self._param.operations == "drop_duplicates":
|
||||
self._drop_duplicates()
|
||||
|
||||
|
||||
def _coerce_n(self):
|
||||
try:
|
||||
return int(getattr(self._param, "n", 0))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _set_outputs(self, outputs):
|
||||
self._param.outputs["result"]["value"] = outputs
|
||||
self._param.outputs["first"]["value"] = outputs[0] if outputs else None
|
||||
self._param.outputs["last"]["value"] = outputs[-1] if outputs else None
|
||||
|
||||
def _topN(self):
|
||||
n = self._coerce_n()
|
||||
if n < 1:
|
||||
outputs = []
|
||||
else:
|
||||
n = min(n, len(self.inputs))
|
||||
outputs = self.inputs[:n]
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _head(self):
|
||||
n = self._coerce_n()
|
||||
if 1 <= n <= len(self.inputs):
|
||||
outputs = [self.inputs[n - 1]]
|
||||
else:
|
||||
outputs = []
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _tail(self):
|
||||
n = self._coerce_n()
|
||||
if 1 <= n <= len(self.inputs):
|
||||
outputs = [self.inputs[-n]]
|
||||
else:
|
||||
outputs = []
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _filter(self):
|
||||
self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])])
|
||||
|
||||
def _norm(self,v):
|
||||
s = "" if v is None else str(v)
|
||||
return s
|
||||
|
||||
def _eval(self, v, operator, value):
|
||||
if operator == "=":
|
||||
return v == value
|
||||
elif operator == "≠":
|
||||
return v != value
|
||||
elif operator == "contains":
|
||||
return value in v
|
||||
elif operator == "start with":
|
||||
return v.startswith(value)
|
||||
elif operator == "end with":
|
||||
return v.endswith(value)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _sort(self):
|
||||
items = self.inputs or []
|
||||
method = getattr(self._param, "sort_method", "asc") or "asc"
|
||||
reverse = method == "desc"
|
||||
|
||||
if not items:
|
||||
self._set_outputs([])
|
||||
return
|
||||
|
||||
first = items[0]
|
||||
|
||||
if isinstance(first, dict):
|
||||
outputs = sorted(
|
||||
items,
|
||||
key=lambda x: self._hashable(x),
|
||||
reverse=reverse,
|
||||
)
|
||||
else:
|
||||
outputs = sorted(items, reverse=reverse)
|
||||
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _drop_duplicates(self):
|
||||
seen = set()
|
||||
outs = []
|
||||
for item in self.inputs:
|
||||
k = self._hashable(item)
|
||||
if k in seen:
|
||||
continue
|
||||
seen.add(k)
|
||||
outs.append(item)
|
||||
self._set_outputs(outs)
|
||||
|
||||
def _hashable(self,x):
|
||||
if isinstance(x, dict):
|
||||
return tuple(sorted((k, self._hashable(v)) for k, v in x.items()))
|
||||
if isinstance(x, (list, tuple)):
|
||||
return tuple(self._hashable(v) for v in x)
|
||||
if isinstance(x, set):
|
||||
return tuple(sorted(self._hashable(v) for v in x))
|
||||
return x
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "ListOperation in progress"
|
||||
@ -13,12 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generator
|
||||
from typing import Any, Generator, AsyncGenerator
|
||||
import json_repair
|
||||
from functools import partial
|
||||
from common.constants import LLMType
|
||||
@ -171,6 +173,13 @@ class LLM(ComponentBase):
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
|
||||
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
return await asyncio.to_thread(self._generate, msg, **kwargs)
|
||||
|
||||
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
@ -205,8 +214,120 @@ class LLM(ComponentBase):
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||
yield delta(txt)
|
||||
|
||||
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
|
||||
async def delta_wrapper(txt_iter):
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal ans, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
ans = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(ans)
|
||||
if ans.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
async for t in txt_iter:
|
||||
yield delta(t)
|
||||
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
|
||||
yield t
|
||||
return
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
|
||||
yield t
|
||||
return
|
||||
|
||||
# fallback
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for item in self._generate_streamly(msg, **kwargs):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||
except Exception as e:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start()
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
async def _stream_output_async(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal answer, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
answer = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(answer)
|
||||
if answer.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
stream_kwargs = {"images": self.imgs} if self.imgs else {}
|
||||
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
|
||||
if isinstance(ans, int):
|
||||
continue
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
yield delta(ans)
|
||||
|
||||
self.set_output("content", answer)
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
@ -217,22 +338,25 @@ class LLM(ComponentBase):
|
||||
|
||||
prompt, msg, _ = self._prepare_prompt_variables()
|
||||
error: str = ""
|
||||
output_structure=None
|
||||
output_structure = None
|
||||
try:
|
||||
output_structure = self._param.outputs['structured']
|
||||
output_structure = self._param.outputs["structured"]
|
||||
except Exception:
|
||||
pass
|
||||
if output_structure:
|
||||
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt += structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
|
||||
schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt_with_schema = prompt + structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
_, msg_fit = message_fit_in(
|
||||
[{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
|
||||
int(self.chat_mdl.max_length * 0.97),
|
||||
)
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
msg.pop(0)
|
||||
ans = await self._generate_async(msg_fit)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
error = ans
|
||||
@ -241,7 +365,7 @@ class LLM(ComponentBase):
|
||||
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
|
||||
return
|
||||
except Exception:
|
||||
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||
msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||
error = "The answer can't not be parsed as JSON"
|
||||
if error:
|
||||
self.set_output("_ERROR", error)
|
||||
@ -249,18 +373,23 @@ class LLM(ComponentBase):
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self._stream_output, prompt, msg))
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
|
||||
ex and ex["goto"]
|
||||
):
|
||||
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
|
||||
return
|
||||
|
||||
for _ in range(self._param.max_retries+1):
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
_, msg_fit = message_fit_in(
|
||||
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
|
||||
)
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
msg.pop(0)
|
||||
ans = await self._generate_async(msg_fit)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
error = ans
|
||||
@ -274,23 +403,9 @@ class LLM(ComponentBase):
|
||||
else:
|
||||
self.set_output("_ERROR", error)
|
||||
|
||||
def _stream_output(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
for ans in self._generate_streamly(msg):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
yield ans
|
||||
answer += ans
|
||||
self.set_output("content", answer)
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||
|
||||
80
agent/component/loop.py
Normal file
80
agent/component/loop.py
Normal file
@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class LoopParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Loop component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loop_variables = []
|
||||
self.loop_termination_condition=[]
|
||||
self.maximum_loop_count = 0
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"items": {
|
||||
"type": "json",
|
||||
"name": "Items"
|
||||
}
|
||||
}
|
||||
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
|
||||
class Loop(ComponentBase, ABC):
|
||||
component_name = "Loop"
|
||||
|
||||
def get_start(self):
|
||||
for cid in self._canvas.components.keys():
|
||||
if self._canvas.get_component(cid)["obj"].component_name.lower() != "loopitem":
|
||||
continue
|
||||
if self._canvas.get_component(cid)["parent_id"] == self._id:
|
||||
return cid
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Loop processing"):
|
||||
return
|
||||
|
||||
for item in self._param.loop_variables:
|
||||
if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]):
|
||||
assert "Loop Variable is not complete."
|
||||
if item["input_mode"]=="variable":
|
||||
self.set_output(item["variable"],self._canvas.get_variable_value(item["value"]))
|
||||
elif item["input_mode"]=="constant":
|
||||
self.set_output(item["variable"],item["value"])
|
||||
else:
|
||||
if item["type"] == "number":
|
||||
self.set_output(item["variable"], 0)
|
||||
elif item["type"] == "string":
|
||||
self.set_output(item["variable"], "")
|
||||
elif item["type"] == "boolean":
|
||||
self.set_output(item["variable"], False)
|
||||
elif item["type"].startswith("object"):
|
||||
self.set_output(item["variable"], {})
|
||||
elif item["type"].startswith("array"):
|
||||
self.set_output(item["variable"], [])
|
||||
else:
|
||||
self.set_output(item["variable"], "")
|
||||
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Loop from canvas."
|
||||
163
agent/component/loopitem.py
Normal file
163
agent/component/loopitem.py
Normal file
@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class LoopItemParam(ComponentParamBase):
|
||||
"""
|
||||
Define the LoopItem component parameters.
|
||||
"""
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
class LoopItem(ComponentBase, ABC):
|
||||
component_name = "LoopItem"
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
super().__init__(canvas, id, param)
|
||||
self._idx = 0
|
||||
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("LoopItem processing"):
|
||||
return
|
||||
parent = self.get_parent()
|
||||
maximum_loop_count = parent._param.maximum_loop_count
|
||||
if self._idx >= maximum_loop_count:
|
||||
self._idx = -1
|
||||
return
|
||||
if self._idx > 0:
|
||||
if self.check_if_canceled("LoopItem processing"):
|
||||
return
|
||||
self._idx += 1
|
||||
|
||||
def evaluate_condition(self,var, operator, value):
|
||||
if isinstance(var, str):
|
||||
if operator == "contains":
|
||||
return value in var
|
||||
elif operator == "not contains":
|
||||
return value not in var
|
||||
elif operator == "start with":
|
||||
return var.startswith(value)
|
||||
elif operator == "end with":
|
||||
return var.endswith(value)
|
||||
elif operator == "is":
|
||||
return var == value
|
||||
elif operator == "is not":
|
||||
return var != value
|
||||
elif operator == "empty":
|
||||
return var == ""
|
||||
elif operator == "not empty":
|
||||
return var != ""
|
||||
|
||||
elif isinstance(var, (int, float)):
|
||||
if operator == "=":
|
||||
return var == value
|
||||
elif operator == "≠":
|
||||
return var != value
|
||||
elif operator == ">":
|
||||
return var > value
|
||||
elif operator == "<":
|
||||
return var < value
|
||||
elif operator == "≥":
|
||||
return var >= value
|
||||
elif operator == "≤":
|
||||
return var <= value
|
||||
elif operator == "empty":
|
||||
return var is None
|
||||
elif operator == "not empty":
|
||||
return var is not None
|
||||
|
||||
elif isinstance(var, bool):
|
||||
if operator == "is":
|
||||
return var is value
|
||||
elif operator == "is not":
|
||||
return var is not value
|
||||
elif operator == "empty":
|
||||
return var is None
|
||||
elif operator == "not empty":
|
||||
return var is not None
|
||||
|
||||
elif isinstance(var, dict):
|
||||
if operator == "empty":
|
||||
return len(var) == 0
|
||||
elif operator == "not empty":
|
||||
return len(var) > 0
|
||||
|
||||
elif isinstance(var, list):
|
||||
if operator == "contains":
|
||||
return value in var
|
||||
elif operator == "not contains":
|
||||
return value not in var
|
||||
|
||||
elif operator == "is":
|
||||
return var == value
|
||||
elif operator == "is not":
|
||||
return var != value
|
||||
|
||||
elif operator == "empty":
|
||||
return len(var) == 0
|
||||
elif operator == "not empty":
|
||||
return len(var) > 0
|
||||
|
||||
raise Exception(f"Invalid operator: {operator}")
|
||||
|
||||
def end(self):
|
||||
if self._idx == -1:
|
||||
return True
|
||||
parent = self.get_parent()
|
||||
logical_operator = parent._param.logical_operator if hasattr(parent._param, "logical_operator") else "and"
|
||||
conditions = []
|
||||
for item in parent._param.loop_termination_condition:
|
||||
if not item.get("variable") or not item.get("operator"):
|
||||
raise ValueError("Loop condition is incomplete.")
|
||||
var = self._canvas.get_variable_value(item["variable"])
|
||||
operator = item["operator"]
|
||||
input_mode = item.get("input_mode", "constant")
|
||||
|
||||
if input_mode == "variable":
|
||||
value = self._canvas.get_variable_value(item.get("value", ""))
|
||||
elif input_mode == "constant":
|
||||
value = item.get("value", "")
|
||||
else:
|
||||
raise ValueError("Invalid input mode.")
|
||||
conditions.append(self.evaluate_condition(var, operator, value))
|
||||
should_end = (
|
||||
all(conditions) if logical_operator == "and"
|
||||
else any(conditions) if logical_operator == "or"
|
||||
else None
|
||||
)
|
||||
if should_end is None:
|
||||
raise ValueError("Invalid logical operator,should be 'and' or 'or'.")
|
||||
|
||||
if should_end:
|
||||
self._idx = -1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def next(self):
|
||||
if self._idx == -1:
|
||||
self._idx = 0
|
||||
else:
|
||||
self._idx += 1
|
||||
if self._idx >= len(self._items):
|
||||
self._idx = -1
|
||||
return False
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Next turn..."
|
||||
@ -13,10 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import logging
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
@ -24,6 +28,8 @@ from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from jinja2 import Template as Jinja2Template
|
||||
|
||||
from common.connection_utils import timeout
|
||||
from common.misc_utils import get_uuid
|
||||
from common import settings
|
||||
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
@ -34,6 +40,8 @@ class MessageParam(ComponentParamBase):
|
||||
super().__init__()
|
||||
self.content = []
|
||||
self.stream = True
|
||||
self.output_format = None # default output format
|
||||
self.auto_play = False
|
||||
self.outputs = {
|
||||
"content": {
|
||||
"type": "str"
|
||||
@ -61,8 +69,12 @@ class Message(ComponentBase):
|
||||
v = ""
|
||||
ans = ""
|
||||
if isinstance(v, partial):
|
||||
for t in v():
|
||||
ans += t
|
||||
iter_obj = v()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
ans = asyncio.run(self._consume_async_gen(iter_obj))
|
||||
else:
|
||||
for t in iter_obj:
|
||||
ans += t
|
||||
elif isinstance(v, list) and delimiter:
|
||||
ans = delimiter.join([str(vv) for vv in v])
|
||||
elif not isinstance(v, str):
|
||||
@ -84,7 +96,13 @@ class Message(ComponentBase):
|
||||
_kwargs[_n] = v
|
||||
return script, _kwargs
|
||||
|
||||
def _stream(self, rand_cnt:str):
|
||||
async def _consume_async_gen(self, agen):
|
||||
buf = ""
|
||||
async for t in agen:
|
||||
buf += t
|
||||
return buf
|
||||
|
||||
async def _stream(self, rand_cnt:str):
|
||||
s = 0
|
||||
all_content = ""
|
||||
cache = {}
|
||||
@ -106,15 +124,27 @@ class Message(ComponentBase):
|
||||
v = ""
|
||||
if isinstance(v, partial):
|
||||
cnt = ""
|
||||
for t in v():
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
iter_obj = v()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
async for t in iter_obj:
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
else:
|
||||
for t in iter_obj:
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
self.set_input_value(exp, cnt)
|
||||
continue
|
||||
elif inspect.isawaitable(v):
|
||||
v = await v
|
||||
elif not isinstance(v, str):
|
||||
try:
|
||||
v = json.dumps(v, ensure_ascii=False)
|
||||
@ -133,6 +163,7 @@ class Message(ComponentBase):
|
||||
yield rand_cnt[s: ]
|
||||
|
||||
self.set_output("content", all_content)
|
||||
self._convert_content(all_content)
|
||||
|
||||
def _is_jinjia2(self, content:str) -> bool:
|
||||
patt = [
|
||||
@ -164,6 +195,72 @@ class Message(ComponentBase):
|
||||
content = re.sub(n, v, content)
|
||||
|
||||
self.set_output("content", content)
|
||||
self._convert_content(content)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
|
||||
def _convert_content(self, content):
|
||||
if not self._param.output_format:
|
||||
return
|
||||
|
||||
import pypandoc
|
||||
doc_id = get_uuid()
|
||||
|
||||
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
|
||||
self._param.output_format = "markdown"
|
||||
|
||||
try:
|
||||
if self._param.output_format in {"markdown", "html"}:
|
||||
if isinstance(content, str):
|
||||
converted = pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
else:
|
||||
converted = pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
|
||||
binary_content = converted.encode("utf-8")
|
||||
|
||||
else: # pdf, docx
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp:
|
||||
tmp_name = tmp.name
|
||||
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
else:
|
||||
pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
|
||||
with open(tmp_name, "rb") as f:
|
||||
binary_content = f.read()
|
||||
|
||||
finally:
|
||||
if os.path.exists(tmp_name):
|
||||
os.remove(tmp_name)
|
||||
|
||||
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
|
||||
self.set_output("attachment", {
|
||||
"doc_id":doc_id,
|
||||
"format":self._param.output_format,
|
||||
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
|
||||
|
||||
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
|
||||
192
agent/component/variable_assigner.py
Normal file
192
agent/component/variable_assigner.py
Normal file
@ -0,0 +1,192 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import os
|
||||
import numbers
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
class VariableAssignerParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Variable Assigner component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.variables=[]
|
||||
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"items": {
|
||||
"type": "json",
|
||||
"name": "Items"
|
||||
}
|
||||
}
|
||||
|
||||
class VariableAssigner(ComponentBase,ABC):
|
||||
component_name = "VariableAssigner"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not isinstance(self._param.variables,list):
|
||||
return
|
||||
else:
|
||||
for item in self._param.variables:
|
||||
if any([not item.get("variable"), not item.get("operator"), not item.get("parameter")]):
|
||||
assert "Variable is not complete."
|
||||
variable=item["variable"]
|
||||
operator=item["operator"]
|
||||
parameter=item["parameter"]
|
||||
variable_value=self._canvas.get_variable_value(variable)
|
||||
new_variable=self._operate(variable_value,operator,parameter)
|
||||
self._canvas.set_variable_value(variable, new_variable)
|
||||
|
||||
def _operate(self,variable,operator,parameter):
|
||||
if operator == "overwrite":
|
||||
return self._overwrite(parameter)
|
||||
elif operator == "clear":
|
||||
return self._clear(variable)
|
||||
elif operator == "set":
|
||||
return self._set(variable,parameter)
|
||||
elif operator == "append":
|
||||
return self._append(variable,parameter)
|
||||
elif operator == "extend":
|
||||
return self._extend(variable,parameter)
|
||||
elif operator == "remove_first":
|
||||
return self._remove_first(variable)
|
||||
elif operator == "remove_last":
|
||||
return self._remove_last(variable)
|
||||
elif operator == "+=":
|
||||
return self._add(variable,parameter)
|
||||
elif operator == "-=":
|
||||
return self._subtract(variable,parameter)
|
||||
elif operator == "*=":
|
||||
return self._multiply(variable,parameter)
|
||||
elif operator == "/=":
|
||||
return self._divide(variable,parameter)
|
||||
else:
|
||||
return
|
||||
|
||||
def _overwrite(self,parameter):
|
||||
return self._canvas.get_variable_value(parameter)
|
||||
|
||||
def _clear(self,variable):
|
||||
if isinstance(variable,list):
|
||||
return []
|
||||
elif isinstance(variable,str):
|
||||
return ""
|
||||
elif isinstance(variable,dict):
|
||||
return {}
|
||||
elif isinstance(variable,int):
|
||||
return 0
|
||||
elif isinstance(variable,float):
|
||||
return 0.0
|
||||
elif isinstance(variable,bool):
|
||||
return False
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set(self,variable,parameter):
|
||||
if variable is None:
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,str):
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,bool):
|
||||
return parameter
|
||||
elif isinstance(variable,int):
|
||||
return parameter
|
||||
elif isinstance(variable,float):
|
||||
return parameter
|
||||
else:
|
||||
return parameter
|
||||
|
||||
def _append(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif len(variable)!=0 and not isinstance(parameter,type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
variable.append(parameter)
|
||||
return variable
|
||||
|
||||
def _extend(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif not isinstance(parameter,list):
|
||||
return "ERROR:PARAMETER_NOT_LIST"
|
||||
elif len(variable)!=0 and len(parameter)!=0 and not isinstance(parameter[0],type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
return variable + parameter
|
||||
|
||||
def _remove_first(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[1:]
|
||||
|
||||
def _remove_last(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[:-1]
|
||||
|
||||
def is_number(self, value):
|
||||
if isinstance(value, bool):
|
||||
return False
|
||||
return isinstance(value, numbers.Number)
|
||||
|
||||
def _add(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable + parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _subtract(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable - parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _multiply(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable * parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _divide(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
if parameter==0:
|
||||
return "ERROR:DIVIDE_BY_ZERO"
|
||||
else:
|
||||
return variable/parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Assign variables from canvas."
|
||||
File diff suppressed because one or more lines are too long
@ -83,10 +83,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
@ -527,10 +527,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
},
|
||||
"label": "ExeSQL",
|
||||
"name": "ExeSQL"
|
||||
|
||||
@ -17,13 +17,13 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from common.misc_utils import hash_str2int
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
from rag.prompts.generator import kb_prompt
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
|
||||
from timeit import default_timer as timer
|
||||
|
||||
|
||||
@ -49,12 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
|
||||
self.callback = callback
|
||||
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
return asyncio.run(self.tool_call_async(name, arguments))
|
||||
|
||||
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||
st = timer()
|
||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
||||
tool_obj = self.tools_map[name]
|
||||
if isinstance(tool_obj, MCPToolCallSession):
|
||||
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
|
||||
else:
|
||||
resp = self.tools_map[name].invoke(**arguments)
|
||||
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||
resp = await tool_obj.invoke_async(**arguments)
|
||||
else:
|
||||
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
|
||||
|
||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||
return resp
|
||||
@ -140,6 +147,33 @@ class ToolBase(ComponentBase):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
async def invoke_async(self, **kwargs):
|
||||
"""
|
||||
Async wrapper for tool invocation.
|
||||
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
|
||||
Mirrors the exception handling of `invoke`.
|
||||
"""
|
||||
if self.check_if_canceled("Tool processing"):
|
||||
return
|
||||
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
fn_async = getattr(self, "_invoke_async", None)
|
||||
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||
res = await fn_async(**kwargs)
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
res = await self._invoke(**kwargs)
|
||||
else:
|
||||
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||
logging.exception(e)
|
||||
res = str(e)
|
||||
self._param.debug_inputs = []
|
||||
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||
chunks = []
|
||||
aggs = []
|
||||
|
||||
@ -13,16 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import ast
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from strenum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from common.connection_utils import timeout
|
||||
from strenum import StrEnum
|
||||
|
||||
from agent.tools.base import ToolBase, ToolMeta, ToolParamBase
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
||||
class Language(StrEnum):
|
||||
@ -62,10 +66,10 @@ class CodeExecParam(ToolParamBase):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
self.meta: ToolMeta = {
|
||||
"name": "execute_code",
|
||||
"description": """
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string.
|
||||
Here's a code example for Python(`main` function MUST be included):
|
||||
def main() -> dict:
|
||||
\"\"\"
|
||||
@ -99,16 +103,12 @@ module.exports = { main };
|
||||
"enum": ["python", "javascript"],
|
||||
"required": True,
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "A piece of code in right format. There MUST be main function.",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
"script": {"type": "string", "description": "A piece of code in right format. There MUST be main function.", "required": True},
|
||||
},
|
||||
}
|
||||
super().__init__()
|
||||
self.lang = Language.PYTHON.value
|
||||
self.script = "def main(arg1: str, arg2: str) -> dict: return {\"result\": arg1 + arg2}"
|
||||
self.script = 'def main(arg1: str, arg2: str) -> dict: return {"result": arg1 + arg2}'
|
||||
self.arguments = {}
|
||||
self.outputs = {"result": {"value": "", "type": "string"}}
|
||||
|
||||
@ -119,17 +119,14 @@ module.exports = { main };
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
res = {}
|
||||
for k, v in self.arguments.items():
|
||||
res[k] = {
|
||||
"type": "line",
|
||||
"name": k
|
||||
}
|
||||
res[k] = {"type": "line", "name": k}
|
||||
return res
|
||||
|
||||
|
||||
class CodeExec(ToolBase, ABC):
|
||||
component_name = "CodeExec"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("CodeExec processing"):
|
||||
return
|
||||
@ -138,17 +135,12 @@ class CodeExec(ToolBase, ABC):
|
||||
script = kwargs.get("script", self._param.script)
|
||||
arguments = {}
|
||||
for k, v in self._param.arguments.items():
|
||||
|
||||
if kwargs.get(k):
|
||||
arguments[k] = kwargs[k]
|
||||
continue
|
||||
arguments[k] = self._canvas.get_variable_value(v) if v else None
|
||||
|
||||
self._execute_code(
|
||||
language=lang,
|
||||
code=script,
|
||||
arguments=arguments
|
||||
)
|
||||
self._execute_code(language=lang, code=script, arguments=arguments)
|
||||
|
||||
def _execute_code(self, language: str, code: str, arguments: dict):
|
||||
import requests
|
||||
@ -169,7 +161,7 @@ class CodeExec(ToolBase, ABC):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return "Task has been canceled"
|
||||
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
|
||||
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
@ -183,35 +175,10 @@ class CodeExec(ToolBase, ABC):
|
||||
if stderr:
|
||||
self.set_output("_ERROR", stderr)
|
||||
return
|
||||
try:
|
||||
rt = eval(body.get("stdout", ""))
|
||||
except Exception:
|
||||
rt = body.get("stdout", "")
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
|
||||
if isinstance(rt, tuple):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[i]
|
||||
elif isinstance(rt, dict):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k not in rt or k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[k]
|
||||
else:
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt
|
||||
raw_stdout = body.get("stdout", "")
|
||||
parsed_stdout = self._deserialize_stdout(raw_stdout)
|
||||
logging.info(f"[CodeExec]: http://{settings.SANDBOX_HOST}:9385/run -> {parsed_stdout}")
|
||||
self._populate_outputs(parsed_stdout, raw_stdout)
|
||||
else:
|
||||
self.set_output("_ERROR", "There is no response from sandbox")
|
||||
|
||||
@ -228,3 +195,149 @@ class CodeExec(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Running a short script to process data."
|
||||
|
||||
def _deserialize_stdout(self, stdout: str):
|
||||
text = str(stdout).strip()
|
||||
if not text:
|
||||
return ""
|
||||
for loader in (json.loads, ast.literal_eval):
|
||||
try:
|
||||
return loader(text)
|
||||
except Exception:
|
||||
continue
|
||||
return text
|
||||
|
||||
def _coerce_output_value(self, value, expected_type: Optional[str]):
|
||||
if expected_type is None:
|
||||
return value
|
||||
|
||||
etype = expected_type.strip().lower()
|
||||
inner_type = None
|
||||
if etype.startswith("array<") and etype.endswith(">"):
|
||||
inner_type = etype[6:-1].strip()
|
||||
etype = "array"
|
||||
|
||||
try:
|
||||
if etype == "string":
|
||||
return "" if value is None else str(value)
|
||||
|
||||
if etype == "number":
|
||||
if value is None or value == "":
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return value
|
||||
return float(value)
|
||||
|
||||
if etype == "boolean":
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
lv = value.lower()
|
||||
if lv in ("true", "1", "yes", "y", "on"):
|
||||
return True
|
||||
if lv in ("false", "0", "no", "n", "off"):
|
||||
return False
|
||||
return bool(value)
|
||||
|
||||
if etype == "array":
|
||||
candidate = value
|
||||
if isinstance(candidate, str):
|
||||
parsed = self._deserialize_stdout(candidate)
|
||||
candidate = parsed
|
||||
if isinstance(candidate, tuple):
|
||||
candidate = list(candidate)
|
||||
if not isinstance(candidate, list):
|
||||
candidate = [] if candidate is None else [candidate]
|
||||
|
||||
if inner_type == "string":
|
||||
return ["" if v is None else str(v) for v in candidate]
|
||||
if inner_type == "number":
|
||||
coerced = []
|
||||
for v in candidate:
|
||||
try:
|
||||
if v is None or v == "":
|
||||
coerced.append(None)
|
||||
elif isinstance(v, (int, float)):
|
||||
coerced.append(v)
|
||||
else:
|
||||
coerced.append(float(v))
|
||||
except Exception:
|
||||
coerced.append(v)
|
||||
return coerced
|
||||
return candidate
|
||||
|
||||
if etype == "object":
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
parsed = self._deserialize_stdout(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return value
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
def _populate_outputs(self, parsed_stdout, raw_stdout: str):
|
||||
outputs_items = list(self._param.outputs.items())
|
||||
logging.info(f"[CodeExec]: outputs schema keys: {[k for k, _ in outputs_items]}")
|
||||
if not outputs_items:
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, dict):
|
||||
for key, meta in outputs_items:
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = self._get_by_path(parsed_stdout, key)
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, (list, tuple)):
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = parsed_stdout[idx] if idx < len(parsed_stdout) else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate list key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
default_val = parsed_stdout if parsed_stdout is not None else raw_stdout
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = default_val if idx == 0 else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate scalar key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
|
||||
def _get_by_path(self, data, path: str):
|
||||
if not path:
|
||||
return None
|
||||
cur = data
|
||||
for part in path.split("."):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
return None
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(part)
|
||||
elif isinstance(cur, list):
|
||||
try:
|
||||
idx = int(part)
|
||||
cur = cur[idx]
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if cur is None:
|
||||
return None
|
||||
logging.info(f"[CodeExec]: resolve path '{path}' -> {cur}")
|
||||
return cur
|
||||
|
||||
@ -132,12 +132,12 @@ class Retrieval(ToolBase, ABC):
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if self._param.meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
filters = gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif self._param.meta_data_filter.get("method") == "manual":
|
||||
filters=self._param.meta_data_filter["manual"]
|
||||
filters = self._param.meta_data_filter["manual"]
|
||||
for flt in filters:
|
||||
pat = re.compile(self.variable_ref_patt)
|
||||
s = flt["value"]
|
||||
@ -165,9 +165,9 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
out_parts.append(s[last:])
|
||||
flt["value"] = "".join(out_parts)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and")))
|
||||
if filters and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
if self._param.cross_languages:
|
||||
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
|
||||
return
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retriever.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
|
||||
@ -14,5 +14,5 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
# from beartype.claw import beartype_this_package
|
||||
# beartype_this_package()
|
||||
|
||||
@ -13,35 +13,35 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flask_cors import CORS
|
||||
from quart import Blueprint, Quart, request, g, current_app, session
|
||||
from flasgger import Swagger
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
from quart_cors import cors
|
||||
from common.constants import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.db_models import close_connection, APIToken
|
||||
from api.db.services import UserService
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from quart_auth import Unauthorized
|
||||
from common import settings
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.constants import API_VERSION
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Flask(__name__)
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
smtp_mail_server = Mail()
|
||||
|
||||
# Add this at the beginning of your file to configure Swagger UI
|
||||
@ -76,32 +76,166 @@ swagger = Swagger(
|
||||
},
|
||||
)
|
||||
|
||||
CORS(app, supports_credentials=True, max_age=2592000)
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
|
||||
# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
|
||||
# Default Quart timeouts are 60 seconds which is too short for many LLM backends
|
||||
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
|
||||
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
|
||||
|
||||
## convince for dev and debug
|
||||
# app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
app.config["SESSION_TYPE"] = "filesystem"
|
||||
app.config["SESSION_TYPE"] = "redis"
|
||||
app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis")
|
||||
app.config["MAX_CONTENT_LENGTH"] = int(
|
||||
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
|
||||
)
|
||||
|
||||
Session(app)
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
app.config['SECRET_KEY'] = settings.SECRET_KEY
|
||||
app.secret_key = settings.SECRET_KEY
|
||||
commands.register_commands(app)
|
||||
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
from collections.abc import Awaitable, Callable
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
def _load_user():
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = request.headers.get("Authorization")
|
||||
g.user = None
|
||||
if not authorization:
|
||||
return
|
||||
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if not user and len(authorization.split()) == 2:
|
||||
objs = APIToken.query(token=authorization.split()[1])
|
||||
if objs:
|
||||
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
|
||||
|
||||
current_user = LocalProxy(_load_user)
|
||||
|
||||
|
||||
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
"""A decorator to restrict route access to authenticated users.
|
||||
|
||||
This should be used to wrap a route handler (or view function) to
|
||||
enforce that only authenticated requests can access it. Note that
|
||||
it is important that this decorator be wrapped by the route
|
||||
decorator and not vice, versa, as below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@app.route('/')
|
||||
@login_required
|
||||
async def index():
|
||||
...
|
||||
|
||||
If the request is not authenticated a
|
||||
`quart.exceptions.Unauthorized` exception will be raised.
|
||||
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if not current_user:# or not session.get("_user_id"):
|
||||
raise Unauthorized()
|
||||
else:
|
||||
return await current_app.ensure_async(func)(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def login_user(user, remember=False, duration=None, force=False, fresh=True):
|
||||
"""
|
||||
Logs a user in. You should pass the actual user object to this. If the
|
||||
user's `is_active` property is ``False``, they will not be logged in
|
||||
unless `force` is ``True``.
|
||||
|
||||
This will return ``True`` if the log in attempt succeeds, and ``False`` if
|
||||
it fails (i.e. because the user is inactive).
|
||||
|
||||
:param user: The user object to log in.
|
||||
:type user: object
|
||||
:param remember: Whether to remember the user after their session expires.
|
||||
Defaults to ``False``.
|
||||
:type remember: bool
|
||||
:param duration: The amount of time before the remember cookie expires. If
|
||||
``None`` the value set in the settings is used. Defaults to ``None``.
|
||||
:type duration: :class:`datetime.timedelta`
|
||||
:param force: If the user is inactive, setting this to ``True`` will log
|
||||
them in regardless. Defaults to ``False``.
|
||||
:type force: bool
|
||||
:param fresh: setting this to ``False`` will log in the user with a session
|
||||
marked as not "fresh". Defaults to ``True``.
|
||||
:type fresh: bool
|
||||
"""
|
||||
if not force and not user.is_active:
|
||||
return False
|
||||
|
||||
session["_user_id"] = user.id
|
||||
session["_fresh"] = fresh
|
||||
session["_id"] = get_uuid()
|
||||
return True
|
||||
|
||||
|
||||
def logout_user():
|
||||
"""
|
||||
Logs a user out. (You do not need to pass the actual user.) This will
|
||||
also clean up the remember me cookie if it exists.
|
||||
"""
|
||||
if "_user_id" in session:
|
||||
session.pop("_user_id")
|
||||
|
||||
if "_fresh" in session:
|
||||
session.pop("_fresh")
|
||||
|
||||
if "_id" in session:
|
||||
session.pop("_id")
|
||||
|
||||
COOKIE_NAME = "remember_token"
|
||||
cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
|
||||
if cookie_name in request.cookies:
|
||||
session["_remember"] = "clear"
|
||||
if "_remember_seconds" in session:
|
||||
session.pop("_remember_seconds")
|
||||
|
||||
return True
|
||||
|
||||
def search_pages_path(page_path):
|
||||
app_path_list = [
|
||||
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
|
||||
]
|
||||
api_path_list = [
|
||||
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(api_path_list)
|
||||
return app_path_list
|
||||
@ -138,44 +272,12 @@ pages_dir = [
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
|
||||
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
|
||||
]
|
||||
|
||||
|
||||
@login_manager.request_loader
|
||||
def load_user(web_request):
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = web_request.headers.get("Authorization")
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return user[0]
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exc):
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
logging.exception(f"Request failed: {exception}")
|
||||
close_connection()
|
||||
|
||||
@ -13,46 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from flask import request, Response
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import APIToken, Task, File
|
||||
from api.db.services import duplicate_name
|
||||
from quart import request
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import queue_tasks, TaskService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
generate_confirmation_token
|
||||
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import keyword_extraction
|
||||
from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from agent.canvas import Canvas
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def new_token():
|
||||
req = request.json
|
||||
async def new_token():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
@ -97,8 +71,8 @@ def token_list():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("tokens", "tenant_id")
|
||||
@login_required
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
for token in req["tokens"]:
|
||||
APITokenService.filter_delete(
|
||||
@ -126,770 +100,18 @@ def stats():
|
||||
"to_date",
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
|
||||
"agent" if "canvas_id" in request.args else None)
|
||||
res = {
|
||||
"pv": [(o["dt"], o["pv"]) for o in objs],
|
||||
"uv": [(o["dt"], o["uv"]) for o in objs],
|
||||
"speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
|
||||
"tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
|
||||
"round": [(o["dt"], o["round"]) for o in objs],
|
||||
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
|
||||
}
|
||||
|
||||
res = {"pv": [], "uv": [], "speed": [], "tokens": [], "round": [], "thumb_up": []}
|
||||
|
||||
for obj in objs:
|
||||
dt = obj["dt"]
|
||||
res["pv"].append((dt, obj["pv"]))
|
||||
res["uv"].append((dt, obj["uv"]))
|
||||
res["speed"].append((dt, float(obj["tokens"]) / (float(obj["duration"]) + 0.1))) # +0.1 to avoid division by zero
|
||||
res["tokens"].append((dt, float(obj["tokens"]) / 1000.0)) # convert to thousands
|
||||
res["round"].append((dt, obj["round"]))
|
||||
res["thumb_up"].append((dt, obj["thumb_up"]))
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
|
||||
def set_conversation():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
if objs[0].source == "agent":
|
||||
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent"
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
else:
|
||||
e, dia = DialogService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": dia.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = False
|
||||
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
def rename_field(ans):
|
||||
reference = ans['reference']
|
||||
if not isinstance(reference, dict):
|
||||
return
|
||||
for chunk_i in reference.get('chunks', []):
|
||||
if 'docnm_kwd' in chunk_i:
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
stream = req.get("stream", True)
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "content": ""}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=stream)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
if stream:
|
||||
assert isinstance(answer, partial), "Nothing. Is it over?"
|
||||
|
||||
def sse():
|
||||
nonlocal answer, cvs, conv
|
||||
try:
|
||||
for ans in answer():
|
||||
for k in ans.keys():
|
||||
final_ans[k] = ans[k]
|
||||
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
canvas.history.append(("assistant", final_ans["content"]))
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
fillin_conv(result)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
rename_field(result)
|
||||
return get_json_result(data=result)
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **req):
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
answer = ans
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
rename_field(answer)
|
||||
return get_json_result(data=answer)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_conversation(conversation_id):
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
e, conv = API4ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
||||
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for referenct_i in conv['reference']:
|
||||
if referenct_i is None or len(referenct_i) == 0:
|
||||
continue
|
||||
for chunk_i in referenct_i['chunks']:
|
||||
if 'docnm_kwd' in chunk_i.keys():
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/upload', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_name")
|
||||
def upload():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
kb_name = request.form.get("kb_name").strip()
|
||||
tenant_id = objs[0].tenant_id
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
kb_root_folder = FileService.get_kb_folder(tenant_id)
|
||||
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
try:
|
||||
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
|
||||
return get_data_error_result(
|
||||
message="Exceed the maximum file number of a free user!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
name=file.filename,
|
||||
kb_id=kb_id)
|
||||
filetype = filename_type(filename)
|
||||
if not filetype:
|
||||
return get_data_error_result(
|
||||
message="This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
|
||||
location += "_"
|
||||
blob = request.files['file'].read()
|
||||
settings.STORAGE_IMPL.put(kb_id, location, blob)
|
||||
doc = {
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": kb.tenant_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
|
||||
form_data = request.form
|
||||
if "parser_id" in form_data.keys():
|
||||
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
|
||||
doc["parser_id"] = request.form.get("parser_id").strip()
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
doc["parser_id"] = ParserType.EMAIL.value
|
||||
|
||||
doc_result = DocumentService.insert(doc)
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if "run" in form_data.keys():
|
||||
if request.form.get("run").strip() == "1":
|
||||
try:
|
||||
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
|
||||
DocumentService.update_by_id(doc["id"], info)
|
||||
# if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
tenant_id = DocumentService.get_tenant_id(doc["id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
# e, doc = DocumentService.get_by_id(doc["id"])
|
||||
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
||||
e, doc = DocumentService.get_by_id(doc["id"])
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=doc_result.to_json())
|
||||
|
||||
|
||||
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id")
|
||||
def upload_parse():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist('file')
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_chunks():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
|
||||
try:
|
||||
if "doc_name" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
|
||||
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
|
||||
|
||||
elif "doc_id" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
|
||||
doc_id = req['doc_id']
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_name or doc_id"
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
"doc_name": res_item["docnm_kwd"],
|
||||
"image_id": res_item["img_id"]
|
||||
} for res_item in res
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=res)
|
||||
|
||||
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_chunk(chunk_id):
|
||||
from rag.nlp import search
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
if chunk is None:
|
||||
return server_error_response(Exception("Chunk not found"))
|
||||
k = []
|
||||
for n in chunk.keys():
|
||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||
k.append(n)
|
||||
for n in k:
|
||||
del chunk[n]
|
||||
|
||||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_kb_docs():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_name = req.get("kb_name", "").strip()
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
page_number = int(req.get("page", 1))
|
||||
items_per_page = int(req.get("page_size", 15))
|
||||
orderby = req.get("orderby", "create_time")
|
||||
desc = req.get("desc", True)
|
||||
keywords = req.get("keywords", "")
|
||||
status = req.get("status", [])
|
||||
if status:
|
||||
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
|
||||
)
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
|
||||
)
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
|
||||
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/infos', methods=['POST']) # noqa: F821
|
||||
@validate_request("doc_ids")
|
||||
def docinfos():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
doc_ids = req["doc_ids"]
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
return get_json_result(data=list(docs.dicts()))
|
||||
|
||||
|
||||
@manager.route('/document', methods=['DELETE']) # noqa: F821
|
||||
# @login_required
|
||||
def document_rm():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
try:
|
||||
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
|
||||
for doc_id in req.get("doc_ids", []):
|
||||
if doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
|
||||
if not doc_ids:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_names or doc_ids"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
|
||||
errors = ""
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
doc_dic = {}
|
||||
for doc in docs:
|
||||
doc_dic[doc.id] = doc
|
||||
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
if doc_id not in doc_dic:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
doc = doc_dic[doc_id]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
|
||||
settings.STORAGE_IMPL.rm(b, n)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
|
||||
@validate_request("Authorization", "conversation_id", "word")
|
||||
def completion_faq():
|
||||
import base64
|
||||
req = request.json
|
||||
|
||||
token = req["Authorization"]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = True
|
||||
|
||||
msg = [{"role": "user", "content": req["word"]}]
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "doc_aggs": []}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=False)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
ans = ""
|
||||
for a in chat(dia, msg, stream=False, **req):
|
||||
ans = a
|
||||
break
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
kb_ids = req.get("kb_id", [])
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
question = req.get("question")
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("page_size", 30))
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
highlight = bool(req.get("highlight", False))
|
||||
|
||||
try:
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embd_nms) != 1:
|
||||
return get_json_result(
|
||||
data=False, message='Knowledge bases use different embedding models or does not exist."',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
|
||||
rank_feature=label_question(question, kbs))
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
code=RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
from common.http_client import async_request, sync_request
|
||||
from .oauth import OAuthClient, UserInfo
|
||||
|
||||
|
||||
@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):
|
||||
|
||||
def fetch_user_info(self, access_token, **kwargs):
|
||||
"""
|
||||
Fetch GitHub user info.
|
||||
Fetch GitHub user info (synchronous).
|
||||
"""
|
||||
user_info = {}
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
# user info
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info.update(response.json())
|
||||
# email info
|
||||
response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
email_info = response.json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email["primary"]), None
|
||||
)["email"]
|
||||
email_response = sync_request(
|
||||
"GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
|
||||
)
|
||||
email_response.raise_for_status()
|
||||
email_info = email_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch github user info: {e}")
|
||||
|
||||
async def async_fetch_user_info(self, access_token, **kwargs):
|
||||
"""Async variant of fetch_user_info using httpx."""
|
||||
user_info = {}
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url,
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_info.update(response.json())
|
||||
|
||||
email_response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url + "/emails",
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
email_response.raise_for_status()
|
||||
email_info = email_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return self.normalize_user_info(user_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch github user info: {e}")
|
||||
|
||||
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
import urllib.parse
|
||||
from common.http_client import async_request, sync_request
|
||||
|
||||
|
||||
class UserInfo:
|
||||
@ -74,15 +74,40 @@ class OAuthClient:
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code"
|
||||
}
|
||||
response = requests.post(
|
||||
response = sync_request(
|
||||
"POST",
|
||||
self.token_url,
|
||||
data=payload,
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=self.http_request_timeout
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to exchange authorization code for token: {e}")
|
||||
|
||||
async def async_exchange_code_for_token(self, code):
|
||||
"""
|
||||
Async variant of exchange_code_for_token using httpx.
|
||||
"""
|
||||
payload = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
try:
|
||||
response = await async_request(
|
||||
"POST",
|
||||
self.token_url,
|
||||
data=payload,
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to exchange authorization code for token: {e}")
|
||||
|
||||
|
||||
@ -92,11 +117,27 @@ class OAuthClient:
|
||||
"""
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch user info: {e}")
|
||||
|
||||
async def async_fetch_user_info(self, access_token, **kwargs):
|
||||
"""Async variant of fetch_user_info using httpx."""
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url,
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
return self.normalize_user_info(user_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch user info: {e}")
|
||||
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from common.http_client import sync_request
|
||||
from .oauth import OAuthClient
|
||||
|
||||
|
||||
@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
|
||||
"""
|
||||
try:
|
||||
metadata_url = f"{issuer}/.well-known/openid-configuration"
|
||||
response = requests.get(metadata_url, timeout=7)
|
||||
response = sync_request("GET", metadata_url, timeout=7)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch OIDC metadata: {e}")
|
||||
|
||||
|
||||
@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
|
||||
user_info.update(super().fetch_user_info(access_token).to_dict())
|
||||
return self.normalize_user_info(user_info)
|
||||
|
||||
async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
|
||||
user_info = {}
|
||||
if id_token:
|
||||
user_info = self.parse_id_token(id_token)
|
||||
user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
|
||||
return self.normalize_user_info(user_info)
|
||||
|
||||
|
||||
def normalize_user_info(self, user_info):
|
||||
return super().normalize_user_info(user_info)
|
||||
|
||||
@ -13,19 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from quart import request, Response, make_response
|
||||
from agent.component import LLM
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
@ -35,17 +29,18 @@ from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
|
||||
get_request_json
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
import time
|
||||
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@ -57,8 +52,9 @@ def templates():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
for i in req["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -70,8 +66,8 @@ def rm():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
async def save():
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
@ -129,18 +125,18 @@ def getsse(canvas_id):
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
async def run():
|
||||
req = await get_request_json()
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -150,7 +146,7 @@ def run():
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
@ -160,10 +156,10 @@ def run():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
def sse():
|
||||
async def sse():
|
||||
nonlocal canvas, user_id
|
||||
try:
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
@ -179,15 +175,15 @@ def run():
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
resp.call_on_close(lambda: canvas.cancel_task())
|
||||
#resp.call_on_close(lambda: canvas.cancel_task())
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
async def rerun():
|
||||
req = await get_request_json()
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
@ -224,8 +220,8 @@ def cancel(task_id):
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
async def reset():
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -245,76 +241,16 @@ def reset():
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
async def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None
|
||||
}
|
||||
|
||||
if request.args.get("url"):
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
DefaultMarkdownGenerator,
|
||||
PruningContentFilter,
|
||||
CrawlResult
|
||||
)
|
||||
try:
|
||||
url = request.args.get("url")
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter()
|
||||
),
|
||||
pdf=True,
|
||||
screenshot=False
|
||||
)
|
||||
result: CrawlResult = await crawler.arun(
|
||||
url=url,
|
||||
config=crawler_config
|
||||
)
|
||||
return result
|
||||
page = trio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
|
||||
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files['file']
|
||||
files = await request.files
|
||||
file = files['file'] if files and files.get("file") else None
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url")))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -342,8 +278,8 @@ def input_form():
|
||||
@manager.route('/debug', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
async def debug():
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -374,8 +310,8 @@ def debug():
|
||||
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
def test_db_connect():
|
||||
req = request.json
|
||||
async def test_db_connect():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
@ -426,7 +362,6 @@ def test_db_connect():
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception as e:
|
||||
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
|
||||
|
||||
@ -438,7 +373,7 @@ def test_db_connect():
|
||||
|
||||
auth = None
|
||||
if http_scheme == "https" and req.get("password"):
|
||||
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
@ -471,8 +406,8 @@ def test_db_connect():
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=list)
|
||||
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=versions)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
@ -520,8 +455,8 @@ def list_canvas():
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
async def setting():
|
||||
req = await get_request_json()
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@ -602,8 +537,8 @@ def prompts():
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
async def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
return await make_response(blob)
|
||||
|
||||
@ -13,13 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
|
||||
import xxhash
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -27,7 +27,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -35,13 +36,14 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def list_chunk():
|
||||
req = request.json
|
||||
async def list_chunk():
|
||||
req = await get_request_json()
|
||||
doc_id = req["doc_id"]
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
@ -121,8 +123,8 @@ def get():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
def set():
|
||||
req = request.json
|
||||
async def set():
|
||||
req = await get_request_json()
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -146,31 +148,35 @@ def set():
|
||||
d["available_int"] = req["available_int"]
|
||||
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
def _set_sync():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
_d = d
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
_d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
_d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_set_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -178,19 +184,22 @@ def set():
|
||||
@manager.route('/switch', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||
def switch():
|
||||
req = request.json
|
||||
async def switch():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
def _switch_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_switch_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -198,23 +207,26 @@ def switch():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
def _rm_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -222,8 +234,8 @@ def rm():
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
def create():
|
||||
req = request.json
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -244,35 +256,38 @@ def create():
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
def _create_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
|
||||
return await asyncio.to_thread(_create_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -280,8 +295,8 @@ def create():
|
||||
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test():
|
||||
req = request.json
|
||||
async def retrieval_test():
|
||||
req = await get_request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -296,25 +311,28 @@ def retrieval_test():
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
tenant_ids = []
|
||||
user_id = current_user.id
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
def _retrieval_sync():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
|
||||
tenants = UserTenantService.query(user_id=user_id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
@ -330,8 +348,9 @@ def retrieval_test():
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
|
||||
_question = question
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -341,19 +360,19 @@ def retrieval_test():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
_question += keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
top,
|
||||
doc_ids, rerank_mdl=rerank_mdl,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl,
|
||||
highlight=req.get("highlight", False),
|
||||
rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question,
|
||||
ck = settings.kg_retriever.retrieval(_question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
@ -366,6 +385,9 @@ def retrieval_test():
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_retrieval_sync)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@ -20,24 +21,24 @@ import uuid
|
||||
from html import escape
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request, make_response
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
|
||||
from api.db import InputType
|
||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
|
||||
from common.constants import RetCode, TaskStatus
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_connector():
|
||||
req = request.json
|
||||
async def set_connector():
|
||||
req = await get_request_json()
|
||||
if req.get("id"):
|
||||
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
||||
ConnectorService.update_by_id(req["id"], conn)
|
||||
@ -55,10 +56,9 @@ def set_connector():
|
||||
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
|
||||
"status": TaskStatus.SCHEDULE,
|
||||
}
|
||||
conn["status"] = TaskStatus.SCHEDULE
|
||||
ConnectorService.save(**conn)
|
||||
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
e, conn = ConnectorService.get_by_id(req["id"])
|
||||
|
||||
return get_json_result(data=conn.to_dict())
|
||||
@ -89,8 +89,8 @@ def list_logs(connector_id):
|
||||
|
||||
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def resume(connector_id):
|
||||
req = request.json
|
||||
async def resume(connector_id):
|
||||
req = await get_request_json()
|
||||
if req.get("resume"):
|
||||
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
|
||||
else:
|
||||
@ -101,8 +101,8 @@ def resume(connector_id):
|
||||
@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rebuild(connector_id):
|
||||
req = request.json
|
||||
async def rebuild(connector_id):
|
||||
req = await get_request_json()
|
||||
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
|
||||
@ -122,12 +122,30 @@ GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
|
||||
WEB_FLOW_TTL_SECS = 15 * 60
|
||||
|
||||
|
||||
def _web_state_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}"
|
||||
def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
||||
"""Return Redis key for web OAuth state.
|
||||
|
||||
The default prefix keeps backward compatibility for Google Drive.
|
||||
When source_type == "gmail", a different prefix is used so that
|
||||
Drive/Gmail flows don't clash in Redis.
|
||||
"""
|
||||
if source_type == "gmail":
|
||||
prefix = "gmail_web_flow_state"
|
||||
else:
|
||||
prefix = GOOGLE_WEB_FLOW_STATE_PREFIX
|
||||
return f"{prefix}:{flow_id}"
|
||||
|
||||
|
||||
def _web_result_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}"
|
||||
def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
||||
"""Return Redis key for web OAuth result.
|
||||
|
||||
Mirrors _web_state_cache_key logic for result storage.
|
||||
"""
|
||||
if source_type == "gmail":
|
||||
prefix = "gmail_web_flow_result"
|
||||
else:
|
||||
prefix = GOOGLE_WEB_FLOW_RESULT_PREFIX
|
||||
return f"{prefix}:{flow_id}"
|
||||
|
||||
|
||||
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
|
||||
@ -146,43 +164,61 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"web": web_section}
|
||||
|
||||
|
||||
def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
|
||||
status = "success" if success else "error"
|
||||
auto_close = "window.close();" if success else ""
|
||||
escaped_message = escape(message)
|
||||
# Drive: ragflow-google-drive-oauth
|
||||
# Gmail: ragflow-gmail-oauth
|
||||
payload_type = f"ragflow-{source}-oauth"
|
||||
payload_json = json.dumps(
|
||||
{
|
||||
"type": "ragflow-google-drive-oauth",
|
||||
"type": payload_type,
|
||||
"status": status,
|
||||
"flowId": flow_id or "",
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
# TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type
|
||||
html = GOOGLE_WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
title=f"Google {source.capitalize()} Authorization",
|
||||
heading="Authorization complete" if success else "Authorization failed",
|
||||
message=escaped_message,
|
||||
payload_json=payload_json,
|
||||
auto_close=auto_close,
|
||||
)
|
||||
response = make_response(html, 200)
|
||||
response = await make_response(html, 200)
|
||||
response.headers["Content-Type"] = "text/html; charset=utf-8"
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("credentials")
|
||||
def start_google_drive_web_oauth():
|
||||
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
|
||||
async def start_google_web_oauth():
|
||||
source = request.args.get("type", "google-drive")
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
||||
|
||||
if source == "gmail":
|
||||
redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
scopes = GOOGLE_SCOPES[DocumentSource.GMAIL]
|
||||
else:
|
||||
redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI if source == "google-drive" else GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE if source == "google-drive" else DocumentSource.GMAIL]
|
||||
|
||||
if not redirect_uri:
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Google Drive OAuth redirect URI is not configured on the server.",
|
||||
message="Google OAuth redirect URI is not configured on the server.",
|
||||
)
|
||||
|
||||
req = request.json or {}
|
||||
req = await get_request_json()
|
||||
raw_credentials = req.get("credentials", "")
|
||||
|
||||
try:
|
||||
credentials = _load_credentials(raw_credentials)
|
||||
print(credentials)
|
||||
except ValueError as exc:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
||||
|
||||
@ -199,8 +235,8 @@ def start_google_drive_web_oauth():
|
||||
|
||||
flow_id = str(uuid.uuid4())
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow = Flow.from_client_config(client_config, scopes=scopes)
|
||||
flow.redirect_uri = redirect_uri
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
include_granted_scopes="true",
|
||||
@ -219,7 +255,7 @@ def start_google_drive_web_oauth():
|
||||
"client_config": client_config,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
@ -230,60 +266,122 @@ def start_google_drive_web_oauth():
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
def google_drive_web_oauth_callback():
|
||||
@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def google_gmail_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
source = "gmail"
|
||||
if source != 'gmail':
|
||||
return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source)
|
||||
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
||||
if not state_cache:
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
|
||||
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
||||
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL])
|
||||
flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
||||
|
||||
return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
|
||||
print("\n\n", _web_result_cache_key(state_id, source), "\n\n")
|
||||
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
|
||||
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def google_drive_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
source = "google-drive"
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source)
|
||||
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
||||
if not state_cache:
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
||||
|
||||
try:
|
||||
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
|
||||
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
||||
|
||||
@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("flow_id")
|
||||
def poll_google_drive_web_result():
|
||||
req = request.json or {}
|
||||
async def poll_google_web_result():
|
||||
req = await request.json or {}
|
||||
source = request.args.get("type")
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
||||
flow_id = req.get("flow_id")
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source))
|
||||
if not cache_raw:
|
||||
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
||||
|
||||
@ -291,5 +389,5 @@ def poll_google_drive_web_result():
|
||||
if result.get("user_id") != current_user.id:
|
||||
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
||||
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id))
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id, source))
|
||||
return get_json_result(data={"credentials": result.get("credentials")})
|
||||
|
||||
@ -14,11 +14,13 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
import tempfile
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
|
||||
@ -26,7 +28,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
from common.constants import RetCode, LLMType
|
||||
@ -34,8 +36,8 @@ from common.constants import RetCode, LLMType
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_conversation():
|
||||
req = request.json
|
||||
async def set_conversation():
|
||||
req = await get_request_json()
|
||||
conv_id = req.get("conversation_id")
|
||||
is_new = req.get("is_new")
|
||||
name = req.get("name", "New conversation")
|
||||
@ -78,14 +80,13 @@ def set_conversation():
|
||||
|
||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
async def get():
|
||||
conv_id = request.args["conversation_id"]
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
avatar = None
|
||||
for tenant in tenants:
|
||||
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
|
||||
if dialog and len(dialog) > 0:
|
||||
@ -129,8 +130,9 @@ def getsse(dialog_id):
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm():
|
||||
conv_ids = request.json["conversation_ids"]
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
conv_ids = req["conversation_ids"]
|
||||
try:
|
||||
for cid in conv_ids:
|
||||
exist, conv = ConversationService.get_by_id(cid)
|
||||
@ -150,7 +152,7 @@ def rm():
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_conversation():
|
||||
async def list_conversation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||
@ -166,8 +168,8 @@ def list_conversation():
|
||||
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
async def completion():
|
||||
req = await get_request_json()
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
@ -248,11 +250,69 @@ def completion():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def sequence2txt():
|
||||
req = await request.form
|
||||
stream_mode = req.get("stream", "false").lower() == "true"
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_data_error_result(message="Missing 'file' in multipart form-data")
|
||||
|
||||
uploaded = files["file"]
|
||||
|
||||
ALLOWED_EXTS = {
|
||||
".wav", ".mp3", ".m4a", ".aac",
|
||||
".flac", ".ogg", ".webm",
|
||||
".opus", ".wma"
|
||||
}
|
||||
|
||||
filename = uploaded.filename or ""
|
||||
suffix = os.path.splitext(filename)[-1].lower()
|
||||
if suffix not in ALLOWED_EXTS:
|
||||
return get_data_error_result(message=
|
||||
f"Unsupported audio format: {suffix}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
|
||||
)
|
||||
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
|
||||
os.close(fd)
|
||||
await uploaded.save(temp_audio_path)
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
asr_id = tenants[0]["asr_id"]
|
||||
if not asr_id:
|
||||
return get_data_error_result(message="No default ASR model is set")
|
||||
|
||||
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
|
||||
if not stream_mode:
|
||||
text = asr_mdl.transcription(temp_audio_path)
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
return get_json_result(data={"text": text})
|
||||
async def event_stream():
|
||||
try:
|
||||
for evt in asr_mdl.stream_transcription(temp_audio_path):
|
||||
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
|
||||
return Response(event_stream(), content_type="text/event-stream")
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def tts():
|
||||
req = request.json
|
||||
async def tts():
|
||||
req = await get_request_json()
|
||||
text = req["text"]
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
@ -284,8 +344,8 @@ def tts():
|
||||
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def delete_msg():
|
||||
req = request.json
|
||||
async def delete_msg():
|
||||
req = await get_request_json()
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -307,8 +367,8 @@ def delete_msg():
|
||||
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def thumbup():
|
||||
req = request.json
|
||||
async def thumbup():
|
||||
req = await get_request_json()
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -334,8 +394,8 @@ def thumbup():
|
||||
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about():
|
||||
req = request.json
|
||||
async def ask_about():
|
||||
req = await get_request_json()
|
||||
uid = current_user.id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -366,8 +426,8 @@ def ask_about():
|
||||
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
req = request.json
|
||||
async def mindmap():
|
||||
req = await get_request_json()
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||
@ -384,8 +444,8 @@ def mindmap():
|
||||
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
def related_questions():
|
||||
req = request.json
|
||||
async def related_questions():
|
||||
req = await get_request_json()
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
@ -402,7 +462,7 @@ def related_questions():
|
||||
if "parameter" in gen_conf:
|
||||
del gen_conf["parameter"]
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
|
||||
@ -14,25 +14,24 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from common.constants import StatusEnum
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("prompt_config")
|
||||
@login_required
|
||||
def set_dialog():
|
||||
req = request.json
|
||||
async def set_dialog():
|
||||
req = await get_request_json()
|
||||
dialog_id = req.get("dialog_id", "")
|
||||
is_create = not dialog_id
|
||||
name = req.get("name", "New Dialog")
|
||||
@ -154,33 +153,34 @@ def get_kb_names(kb_ids):
|
||||
@login_required
|
||||
def list_dialogs():
|
||||
try:
|
||||
diags = DialogService.query(
|
||||
conversations = DialogService.query(
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value,
|
||||
reverse=True,
|
||||
order_by=DialogService.model.create_time)
|
||||
diags = [d.to_dict() for d in diags]
|
||||
for d in diags:
|
||||
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
|
||||
return get_json_result(data=diags)
|
||||
conversations = [d.to_dict() for d in conversations]
|
||||
for conversation in conversations:
|
||||
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
|
||||
return get_json_result(data=conversations)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/next', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_dialogs_next():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
async def list_dialogs_next():
|
||||
args = request.args
|
||||
keywords = args.get("keywords", "")
|
||||
page_number = int(args.get("page", 0))
|
||||
items_per_page = int(args.get("page_size", 0))
|
||||
parser_id = args.get("parser_id")
|
||||
orderby = args.get("orderby", "create_time")
|
||||
if args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -207,8 +207,8 @@ def list_dialogs_next():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("dialog_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
dialog_list=[]
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
try:
|
||||
|
||||
@ -13,16 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from quart import request, make_response
|
||||
from api.apps import current_user, login_required
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
@ -39,7 +37,7 @@ from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
validate_request, get_request_json,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -53,14 +51,16 @@ from common import settings
|
||||
@manager.route("/upload", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def upload():
|
||||
kb_id = request.form.get("kb_id")
|
||||
async def upload():
|
||||
form = await request.form
|
||||
kb_id = form.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
if "file" not in request.files:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -73,7 +73,7 @@ def upload():
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, files = FileService.upload_document(kb, file_objs, current_user.id)
|
||||
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
|
||||
@ -87,12 +87,13 @@ def upload():
|
||||
@manager.route("/web_crawl", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "url")
|
||||
def web_crawl():
|
||||
kb_id = request.form.get("kb_id")
|
||||
async def web_crawl():
|
||||
form = await request.form
|
||||
kb_id = form.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
name = request.form.get("name")
|
||||
url = request.form.get("url")
|
||||
name = form.get("name")
|
||||
url = form.get("url")
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
@ -152,8 +153,8 @@ def web_crawl():
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "kb_id")
|
||||
def create():
|
||||
req = request.json
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
kb_id = req["kb_id"]
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -208,7 +209,7 @@ def create():
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_docs():
|
||||
async def list_docs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -230,7 +231,7 @@ def list_docs():
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
run_status = req.get("run_status", [])
|
||||
if run_status:
|
||||
@ -270,8 +271,8 @@ def list_docs():
|
||||
|
||||
@manager.route("/filter", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def get_filter():
|
||||
req = request.get_json()
|
||||
async def get_filter():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
@ -308,8 +309,8 @@ def get_filter():
|
||||
|
||||
@manager.route("/infos", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def docinfos():
|
||||
req = request.json
|
||||
async def doc_infos():
|
||||
req = await get_request_json()
|
||||
doc_ids = req["doc_ids"]
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
@ -340,8 +341,8 @@ def thumbnails():
|
||||
@manager.route("/change_status", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids", "status")
|
||||
def change_status():
|
||||
req = request.get_json()
|
||||
async def change_status():
|
||||
req = await get_request_json()
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
status = str(req.get("status", ""))
|
||||
|
||||
@ -380,8 +381,8 @@ def change_status():
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
doc_ids = req["doc_id"]
|
||||
if isinstance(doc_ids, str):
|
||||
doc_ids = [doc_ids]
|
||||
@ -390,7 +391,7 @@ def rm():
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
errors = FileService.delete_docs(doc_ids, current_user.id)
|
||||
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
@ -401,46 +402,50 @@ def rm():
|
||||
@manager.route("/run", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids", "run")
|
||||
def run():
|
||||
req = request.json
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
async def run():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
kb_table_num_map = {}
|
||||
for id in req["doc_ids"]:
|
||||
info = {"run": str(req["run"]), "progress": 0}
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
def _run_sync():
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
kb_table_num_map = {}
|
||||
for id in req["doc_ids"]:
|
||||
info = {"run": str(req["run"]), "progress": 0}
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
|
||||
if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc_dict = doc.to_dict()
|
||||
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_run_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -448,66 +453,72 @@ def run():
|
||||
@manager.route("/rename", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
async def rename():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||
def _rename_sync():
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||
if d.name == req["name"]:
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||
if d.name == req["name"]:
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
title_tks = rag_tokenizer.tokenize(req["name"])
|
||||
es_body = {
|
||||
"docnm_kwd": req["name"],
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
title_tks = rag_tokenizer.tokenize(req["name"])
|
||||
es_body = {
|
||||
"docnm_kwd": req["name"],
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rename_sync)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def get(doc_id):
|
||||
async def get(doc_id):
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
response = await make_response(data)
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
if doc.type == FileType.VISUAL.value:
|
||||
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
@ -517,12 +528,27 @@ def get(doc_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def download_attachment(attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/change_parser", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def change_parser():
|
||||
async def change_parser():
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -544,6 +570,7 @@ def change_parser():
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
if "pipeline_id" in req and req["pipeline_id"] != "":
|
||||
@ -572,13 +599,14 @@ def change_parser():
|
||||
|
||||
@manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def get_image(image_id):
|
||||
async def get_image(image_id):
|
||||
try:
|
||||
arr = image_id.split("-")
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
except Exception as e:
|
||||
@ -588,24 +616,26 @@ def get_image(image_id):
|
||||
@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id")
|
||||
def upload_and_parse():
|
||||
if "file" not in request.files:
|
||||
async def upload_and_parse():
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
||||
|
||||
form = await request.form
|
||||
doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route("/parse", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def parse():
|
||||
url = request.json.get("url") if request.json else ""
|
||||
async def parse():
|
||||
req = await get_request_json()
|
||||
url = req.get("url", "")
|
||||
if url:
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -646,10 +676,11 @@ def parse():
|
||||
txt = FileService.parse_docs([f], current_user.id)
|
||||
return get_json_result(data=txt)
|
||||
|
||||
if "file" not in request.files:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
txt = FileService.parse_docs(file_objs, current_user.id)
|
||||
|
||||
return get_json_result(data=txt)
|
||||
@ -658,8 +689,8 @@ def parse():
|
||||
@manager.route("/set_meta", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "meta")
|
||||
def set_meta():
|
||||
req = request.json
|
||||
async def set_meta():
|
||||
req = await get_request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
@ -685,3 +716,13 @@ def set_meta():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload_info", methods=["POST"]) # noqa: F821
|
||||
async def upload_info():
|
||||
files = await request.files
|
||||
file = files['file'] if files and files.get("file") else None
|
||||
try:
|
||||
return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url")))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
479
api/apps/evaluation_app.py
Normal file
479
api/apps/evaluation_app.py
Normal file
@ -0,0 +1,479 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
RAG Evaluation API Endpoints
|
||||
|
||||
Provides REST API for RAG evaluation functionality including:
|
||||
- Dataset management
|
||||
- Test case management
|
||||
- Evaluation execution
|
||||
- Results retrieval
|
||||
- Configuration recommendations
|
||||
"""
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.evaluation_service import EvaluationService
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
get_request_json,
|
||||
server_error_response,
|
||||
validate_request
|
||||
)
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
# ==================== Dataset Management ====================
|
||||
|
||||
@manager.route('/dataset/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "kb_ids")
|
||||
async def create_dataset():
|
||||
"""
|
||||
Create a new evaluation dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"name": "Dataset name",
|
||||
"description": "Optional description",
|
||||
"kb_ids": ["kb_id1", "kb_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
name = req.get("name", "").strip()
|
||||
description = req.get("description", "")
|
||||
kb_ids = req.get("kb_ids", [])
|
||||
|
||||
if not name:
|
||||
return get_data_error_result(message="Dataset name cannot be empty")
|
||||
|
||||
if not kb_ids or not isinstance(kb_ids, list):
|
||||
return get_data_error_result(message="kb_ids must be a non-empty list")
|
||||
|
||||
success, result = EvaluationService.create_dataset(
|
||||
name=name,
|
||||
description=description,
|
||||
kb_ids=kb_ids,
|
||||
tenant_id=current_user.id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"dataset_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def list_datasets():
|
||||
"""
|
||||
List evaluation datasets for current tenant.
|
||||
|
||||
Query params:
|
||||
- page: Page number (default: 1)
|
||||
- page_size: Items per page (default: 20)
|
||||
"""
|
||||
try:
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = int(request.args.get("page_size", 20))
|
||||
|
||||
result = EvaluationService.list_datasets(
|
||||
tenant_id=current_user.id,
|
||||
user_id=current_user.id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_dataset(dataset_id):
|
||||
"""Get dataset details by ID"""
|
||||
try:
|
||||
dataset = EvaluationService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
return get_data_error_result(
|
||||
message="Dataset not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=dataset)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
async def update_dataset(dataset_id):
|
||||
"""
|
||||
Update dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"name": "New name",
|
||||
"description": "New description",
|
||||
"kb_ids": ["kb_id1", "kb_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
|
||||
# Remove fields that shouldn't be updated
|
||||
req.pop("id", None)
|
||||
req.pop("tenant_id", None)
|
||||
req.pop("created_by", None)
|
||||
req.pop("create_time", None)
|
||||
|
||||
success = EvaluationService.update_dataset(dataset_id, **req)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to update dataset")
|
||||
|
||||
return get_json_result(data={"dataset_id": dataset_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_dataset(dataset_id):
|
||||
"""Delete dataset (soft delete)"""
|
||||
try:
|
||||
success = EvaluationService.delete_dataset(dataset_id)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to delete dataset")
|
||||
|
||||
return get_json_result(data={"dataset_id": dataset_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Test Case Management ====================
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
async def add_test_case(dataset_id):
|
||||
"""
|
||||
Add a test case to a dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"question": "Test question",
|
||||
"reference_answer": "Optional ground truth answer",
|
||||
"relevant_doc_ids": ["doc_id1", "doc_id2"],
|
||||
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
question = req.get("question", "").strip()
|
||||
|
||||
if not question:
|
||||
return get_data_error_result(message="Question cannot be empty")
|
||||
|
||||
success, result = EvaluationService.add_test_case(
|
||||
dataset_id=dataset_id,
|
||||
question=question,
|
||||
reference_answer=req.get("reference_answer"),
|
||||
relevant_doc_ids=req.get("relevant_doc_ids"),
|
||||
relevant_chunk_ids=req.get("relevant_chunk_ids"),
|
||||
metadata=req.get("metadata")
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"case_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("cases")
|
||||
async def import_test_cases(dataset_id):
|
||||
"""
|
||||
Bulk import test cases.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"cases": [
|
||||
{
|
||||
"question": "Question 1",
|
||||
"reference_answer": "Answer 1",
|
||||
...
|
||||
},
|
||||
{
|
||||
"question": "Question 2",
|
||||
...
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
cases = req.get("cases", [])
|
||||
|
||||
if not cases or not isinstance(cases, list):
|
||||
return get_data_error_result(message="cases must be a non-empty list")
|
||||
|
||||
success_count, failure_count = EvaluationService.import_test_cases(
|
||||
dataset_id=dataset_id,
|
||||
cases=cases
|
||||
)
|
||||
|
||||
return get_json_result(data={
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
"total": len(cases)
|
||||
})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_test_cases(dataset_id):
|
||||
"""Get all test cases for a dataset"""
|
||||
try:
|
||||
cases = EvaluationService.get_test_cases(dataset_id)
|
||||
return get_json_result(data={"cases": cases, "total": len(cases)})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_test_case(case_id):
|
||||
"""Delete a test case"""
|
||||
try:
|
||||
success = EvaluationService.delete_test_case(case_id)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to delete test case")
|
||||
|
||||
return get_json_result(data={"case_id": case_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Evaluation Execution ====================
|
||||
|
||||
@manager.route('/run/start', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("dataset_id", "dialog_id")
|
||||
async def start_evaluation():
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"dataset_id": "dataset_id",
|
||||
"dialog_id": "dialog_id",
|
||||
"name": "Optional run name"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
dataset_id = req.get("dataset_id")
|
||||
dialog_id = req.get("dialog_id")
|
||||
name = req.get("name")
|
||||
|
||||
success, result = EvaluationService.start_evaluation(
|
||||
dataset_id=dataset_id,
|
||||
dialog_id=dialog_id,
|
||||
user_id=current_user.id,
|
||||
name=name
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"run_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_evaluation_run(run_id):
|
||||
"""Get evaluation run details"""
|
||||
try:
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_run_results(run_id):
|
||||
"""Get detailed results for an evaluation run"""
|
||||
try:
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def list_evaluation_runs():
|
||||
"""
|
||||
List evaluation runs.
|
||||
|
||||
Query params:
|
||||
- dataset_id: Filter by dataset (optional)
|
||||
- dialog_id: Filter by dialog (optional)
|
||||
- page: Page number (default: 1)
|
||||
- page_size: Items per page (default: 20)
|
||||
"""
|
||||
try:
|
||||
# TODO: Implement list_runs in EvaluationService
|
||||
return get_json_result(data={"runs": [], "total": 0})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_evaluation_run(run_id):
|
||||
"""Delete an evaluation run"""
|
||||
try:
|
||||
# TODO: Implement delete_run in EvaluationService
|
||||
return get_json_result(data={"run_id": run_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Analysis & Recommendations ====================
|
||||
|
||||
@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_recommendations(run_id):
|
||||
"""Get configuration recommendations based on evaluation results"""
|
||||
try:
|
||||
recommendations = EvaluationService.get_recommendations(run_id)
|
||||
return get_json_result(data={"recommendations": recommendations})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/compare', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("run_ids")
|
||||
async def compare_runs():
|
||||
"""
|
||||
Compare multiple evaluation runs.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"run_ids": ["run_id1", "run_id2", "run_id3"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
run_ids = req.get("run_ids", [])
|
||||
|
||||
if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
|
||||
return get_data_error_result(
|
||||
message="run_ids must be a list with at least 2 run IDs"
|
||||
)
|
||||
|
||||
# TODO: Implement compare_runs in EvaluationService
|
||||
return get_json_result(data={"comparison": {}})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def export_results(run_id):
|
||||
"""Export evaluation results as JSON/CSV"""
|
||||
try:
|
||||
# format_type = request.args.get("format", "json") # TODO: Use for CSV export
|
||||
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
# TODO: Implement CSV export
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Real-time Evaluation ====================
|
||||
|
||||
@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "dialog_id")
|
||||
async def evaluate_single():
|
||||
"""
|
||||
Evaluate a single question-answer pair in real-time.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"question": "Test question",
|
||||
"dialog_id": "dialog_id",
|
||||
"reference_answer": "Optional ground truth",
|
||||
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# req = await get_request_json() # TODO: Use for single evaluation implementation
|
||||
|
||||
# TODO: Implement single evaluation
|
||||
# This would execute the RAG pipeline and return metrics immediately
|
||||
|
||||
return get_json_result(data={
|
||||
"answer": "",
|
||||
"metrics": {},
|
||||
"retrieved_chunks": []
|
||||
})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -19,22 +19,20 @@ from pathlib import Path
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.db import FileType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/convert', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids", "kb_ids")
|
||||
def convert():
|
||||
req = request.json
|
||||
async def convert():
|
||||
req = await get_request_json()
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
@ -79,7 +77,8 @@ def convert():
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": file.type,
|
||||
@ -103,8 +102,8 @@ def convert():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
if not file_ids:
|
||||
return get_json_result(
|
||||
|
||||
@ -14,13 +14,12 @@
|
||||
# limitations under the License
|
||||
#
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request, make_response
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -31,7 +30,7 @@ from common.constants import RetCode, FileSource
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.api_utils import get_json_result, get_request_json
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
@ -40,17 +39,19 @@ from common import settings
|
||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
# @validate_request("parent_id")
|
||||
def upload():
|
||||
pf_id = request.form.get("parent_id")
|
||||
async def upload():
|
||||
form = await request.form
|
||||
pf_id = form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
files = await request.files
|
||||
if 'file' not in files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
file_objs = request.files.getlist('file')
|
||||
file_objs = files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
@ -61,9 +62,10 @@ def upload():
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_data_error_result( message="Can't find this folder!")
|
||||
for file_obj in file_objs:
|
||||
|
||||
async def _handle_single_file(file_obj):
|
||||
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
|
||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||
|
||||
# split file name path
|
||||
@ -75,35 +77,36 @@ def upload():
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# get folder
|
||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# create folder
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
# file type
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||
location += "_"
|
||||
blob = file_obj.read()
|
||||
filename = duplicate_name(
|
||||
blob = await asyncio.to_thread(file_obj.read)
|
||||
filename = await asyncio.to_thread(
|
||||
duplicate_name,
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
parent_id=last_folder.id)
|
||||
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
|
||||
file = {
|
||||
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||
file_data = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": last_folder.id,
|
||||
"tenant_id": current_user.id,
|
||||
@ -113,8 +116,13 @@ def upload():
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
}
|
||||
file = FileService.insert(file)
|
||||
file_res.append(file.to_json())
|
||||
inserted = await asyncio.to_thread(FileService.insert, file_data)
|
||||
return inserted.to_json()
|
||||
|
||||
for file_obj in file_objs:
|
||||
res = await _handle_single_file(file_obj)
|
||||
file_res.append(res)
|
||||
|
||||
return get_json_result(data=file_res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -123,10 +131,10 @@ def upload():
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
pf_id = req.get("parent_id")
|
||||
input_file_type = req.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@ -238,59 +246,62 @@ def get_all_parent_folders():
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
if file.location:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e and doc:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if tenant_id:
|
||||
DocumentService.remove_document(doc, tenant_id)
|
||||
File2DocumentService.delete_by_file_id(file.id)
|
||||
|
||||
FileService.delete(file)
|
||||
|
||||
def _delete_folder_recursive(folder, tenant_id):
|
||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||
for sub_file in sub_files:
|
||||
if sub_file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(sub_file, tenant_id)
|
||||
else:
|
||||
_delete_single_file(sub_file)
|
||||
|
||||
FileService.delete(folder)
|
||||
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e or not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
if file.location:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception as e:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e and doc:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if tenant_id:
|
||||
DocumentService.remove_document(doc, tenant_id)
|
||||
File2DocumentService.delete_by_file_id(file.id)
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
continue
|
||||
FileService.delete(file)
|
||||
|
||||
_delete_single_file(file)
|
||||
def _delete_folder_recursive(folder, tenant_id):
|
||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||
for sub_file in sub_files:
|
||||
if sub_file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(sub_file, tenant_id)
|
||||
else:
|
||||
_delete_single_file(sub_file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
FileService.delete(folder)
|
||||
|
||||
def _rm_sync():
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e or not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
continue
|
||||
|
||||
_delete_single_file(file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -299,8 +310,8 @@ def rm():
|
||||
@manager.route('/rename', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
async def rename():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
@ -338,7 +349,7 @@ def rename():
|
||||
|
||||
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(file_id):
|
||||
async def get(file_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@ -346,12 +357,12 @@ def get(file_id):
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = settings.STORAGE_IMPL.get(b, n)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
@ -368,8 +379,8 @@ def get(file_id):
|
||||
@manager.route("/mv", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("src_file_ids", "dest_file_id")
|
||||
def move():
|
||||
req = request.json
|
||||
async def move():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
dest_parent_id = req["dest_file_id"]
|
||||
@ -444,10 +455,12 @@ def move():
|
||||
},
|
||||
)
|
||||
|
||||
for file in files:
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
def _move_sync():
|
||||
for file in files:
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return get_json_result(data=True)
|
||||
return await asyncio.to_thread(_move_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -16,12 +16,12 @@
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
import numpy as np
|
||||
|
||||
|
||||
from api.db.services.connector_service import Connector2KbService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
@ -30,7 +30,8 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
|
||||
get_request_json
|
||||
from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
@ -41,23 +42,28 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
e, res = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = current_user.id,
|
||||
parser_id = req.pop("parser_id", None),
|
||||
**req
|
||||
)
|
||||
|
||||
if not e:
|
||||
return res
|
||||
|
||||
try:
|
||||
if not KnowledgebaseService.save(**req):
|
||||
if not KnowledgebaseService.save(**res):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id":req["id"]})
|
||||
return get_json_result(data={"kb_id":res["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -66,8 +72,8 @@ def create():
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "description", "parser_id")
|
||||
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.json
|
||||
async def update():
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -111,12 +117,22 @@ def update():
|
||||
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.update,
|
||||
{"kb_id": kb.id},
|
||||
{PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb.id,
|
||||
)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.update,
|
||||
{"exists": PAGERANK_FLD},
|
||||
{"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb.id,
|
||||
)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e:
|
||||
@ -165,18 +181,19 @@ def detail():
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_kbs():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
async def list_kbs():
|
||||
args = request.args
|
||||
keywords = args.get("keywords", "")
|
||||
page_number = int(args.get("page", 0))
|
||||
items_per_page = int(args.get("page_size", 0))
|
||||
parser_id = args.get("parser_id")
|
||||
orderby = args.get("orderby", "create_time")
|
||||
if args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -198,11 +215,12 @@ def list_kbs():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -217,25 +235,28 @@ def rm():
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
def _rm_sync():
|
||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
if f2d:
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
if f2d:
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -278,8 +299,8 @@ def list_tags_from_kbs():
|
||||
|
||||
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rm_tags(kb_id):
|
||||
req = request.json
|
||||
async def rm_tags(kb_id):
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -298,8 +319,8 @@ def rm_tags(kb_id):
|
||||
|
||||
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rename_tags(kb_id):
|
||||
req = request.json
|
||||
async def rename_tags(kb_id):
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -402,7 +423,7 @@ def get_basic_info():
|
||||
|
||||
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_logs():
|
||||
async def list_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -421,7 +442,7 @@ def list_pipeline_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -446,7 +467,7 @@ def list_pipeline_logs():
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
async def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -463,7 +484,7 @@ def list_pipeline_dataset_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -480,12 +501,12 @@ def list_pipeline_dataset_logs():
|
||||
|
||||
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_pipeline_logs():
|
||||
async def delete_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
log_ids = req.get("log_ids", [])
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
@ -509,8 +530,8 @@ def pipeline_log_detail():
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
async def run_graphrag():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -578,8 +599,8 @@ def trace_graphrag():
|
||||
|
||||
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_raptor():
|
||||
req = request.json
|
||||
async def run_raptor():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -647,8 +668,8 @@ def trace_raptor():
|
||||
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_mindmap():
|
||||
req = request.json
|
||||
async def run_mindmap():
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -731,6 +752,8 @@ def delete_kb_task():
|
||||
def cancel_task(task_id):
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
|
||||
kb_task_id_field: str = ""
|
||||
kb_task_finish_at: str = ""
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
kb_task_id_field = "graphrag_task_id"
|
||||
@ -761,7 +784,7 @@ def delete_kb_task():
|
||||
|
||||
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
def check_embedding():
|
||||
async def check_embedding():
|
||||
|
||||
def _guess_vec_field(src: dict) -> str | None:
|
||||
for k in src or {}:
|
||||
@ -807,12 +830,12 @@ def check_embedding():
|
||||
offset=0, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
total = docStoreConn.getTotal(res0)
|
||||
total = docStoreConn.get_total(res0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
n = min(n, total)
|
||||
offsets = sorted(random.sample(range(total), n))
|
||||
offsets = sorted(random.sample(range(min(total,1000)), n))
|
||||
out = []
|
||||
|
||||
for off in offsets:
|
||||
@ -824,7 +847,7 @@ def check_embedding():
|
||||
offset=off, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
ids = docStoreConn.getChunkIds(res1)
|
||||
ids = docStoreConn.get_chunk_ids(res1)
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
@ -845,9 +868,14 @@ def check_embedding():
|
||||
"position_int": full_doc.get("position_int"),
|
||||
"top_int": full_doc.get("top_int"),
|
||||
"content_with_weight": full_doc.get("content_with_weight") or "",
|
||||
"question_kwd": full_doc.get("question_kwd") or []
|
||||
})
|
||||
return out
|
||||
req = request.json
|
||||
|
||||
def _clean(s: str) -> str:
|
||||
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
|
||||
return s if s else "None"
|
||||
req = await get_request_json()
|
||||
kb_id = req.get("kb_id", "")
|
||||
embd_id = req.get("embd_id", "")
|
||||
n = int(req.get("check_num", 5))
|
||||
@ -859,8 +887,10 @@ def check_embedding():
|
||||
|
||||
results, eff_sims = [], []
|
||||
for ck in samples:
|
||||
txt = (ck.get("content_with_weight") or "").strip()
|
||||
if not txt:
|
||||
title = ck.get("doc_name") or "Title"
|
||||
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
|
||||
txt_in = _clean(txt_in)
|
||||
if not txt_in:
|
||||
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
|
||||
continue
|
||||
|
||||
@ -869,10 +899,19 @@ def check_embedding():
|
||||
continue
|
||||
|
||||
try:
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
sim = _cos_sim(qv, ck["vector"])
|
||||
except Exception:
|
||||
return get_error_data_result(message="embedding failure")
|
||||
v, _ = emb_mdl.encode([title, txt_in])
|
||||
assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})"
|
||||
sim_content = _cos_sim(v[1], ck["vector"])
|
||||
title_w = 0.1
|
||||
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
|
||||
sim_mix = _cos_sim(qv_mix, ck["vector"])
|
||||
sim = sim_content
|
||||
mode = "content_only"
|
||||
if sim_mix > sim:
|
||||
sim = sim_mix
|
||||
mode = "title+content"
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Embedding failure. {e}")
|
||||
|
||||
eff_sims.append(sim)
|
||||
results.append({
|
||||
@ -892,9 +931,8 @@ def check_embedding():
|
||||
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"match_mode": mode,
|
||||
}
|
||||
if summary["avg_cos_sim"] > 0.99:
|
||||
if summary["avg_cos_sim"] > 0.9:
|
||||
return get_json_result(data={"summary": summary, "results": results})
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
||||
|
||||
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
|
||||
|
||||
@ -15,20 +15,19 @@
|
||||
#
|
||||
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from api.apps import current_user, login_required
|
||||
from langfuse import Langfuse
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("secret_key", "public_key", "host")
|
||||
def set_api_key():
|
||||
req = request.get_json()
|
||||
async def set_api_key():
|
||||
req = await get_request_json()
|
||||
secret_key = req.get("secret_key", "")
|
||||
public_key = req.get("public_key", "")
|
||||
host = req.get("host", "")
|
||||
|
||||
@ -16,14 +16,14 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.constants import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
|
||||
from rag.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
|
||||
@ -52,8 +52,8 @@ def factories():
|
||||
@manager.route("/set_api_key", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
async def set_api_key():
|
||||
req = await get_request_json()
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
@ -122,8 +122,8 @@ def set_api_key():
|
||||
@manager.route("/add_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def add_llm():
|
||||
req = request.json
|
||||
async def add_llm():
|
||||
req = await get_request_json()
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
@ -142,11 +142,11 @@ def add_llm():
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
@ -267,8 +267,8 @@ def add_llm():
|
||||
@manager.route("/delete_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def delete_llm():
|
||||
req = request.json
|
||||
async def delete_llm():
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -276,8 +276,8 @@ def delete_llm():
|
||||
@manager.route("/enable_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def enable_llm():
|
||||
req = request.json
|
||||
async def enable_llm():
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
|
||||
)
|
||||
@ -287,8 +287,8 @@ def enable_llm():
|
||||
@manager.route("/delete_factory", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def delete_factory():
|
||||
req = request.json
|
||||
async def delete_factory():
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
from api.db.db_models import MCPServer
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
@ -22,15 +22,14 @@ from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_mcp_tools
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_mcp() -> Response:
|
||||
async def list_mcp() -> Response:
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
@ -40,7 +39,7 @@ def list_mcp() -> Response:
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
try:
|
||||
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
|
||||
@ -72,8 +71,8 @@ def detail() -> Response:
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "url", "server_type")
|
||||
def create() -> Response:
|
||||
req = request.get_json()
|
||||
async def create() -> Response:
|
||||
req = await get_request_json()
|
||||
|
||||
server_type = req.get("server_type", "")
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
@ -127,8 +126,8 @@ def create() -> Response:
|
||||
@manager.route("/update", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id")
|
||||
def update() -> Response:
|
||||
req = request.get_json()
|
||||
async def update() -> Response:
|
||||
req = await get_request_json()
|
||||
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
@ -183,8 +182,8 @@ def update() -> Response:
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def rm() -> Response:
|
||||
req = request.get_json()
|
||||
async def rm() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
try:
|
||||
@ -201,8 +200,8 @@ def rm() -> Response:
|
||||
@manager.route("/import", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcpServers")
|
||||
def import_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
async def import_multiple() -> Response:
|
||||
req = await get_request_json()
|
||||
servers = req.get("mcpServers", {})
|
||||
if not servers:
|
||||
return get_data_error_result(message="No MCP servers provided.")
|
||||
@ -268,8 +267,8 @@ def import_multiple() -> Response:
|
||||
@manager.route("/export", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def export_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
async def export_multiple() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
if not mcp_ids:
|
||||
@ -300,8 +299,8 @@ def export_multiple() -> Response:
|
||||
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def list_tools() -> Response:
|
||||
req = request.get_json()
|
||||
async def list_tools() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
@ -347,8 +346,8 @@ def list_tools() -> Response:
|
||||
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tool_name", "arguments")
|
||||
def test_tool() -> Response:
|
||||
req = request.get_json()
|
||||
async def test_tool() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -380,8 +379,8 @@ def test_tool() -> Response:
|
||||
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tools")
|
||||
def cache_tool() -> Response:
|
||||
req = request.get_json()
|
||||
async def cache_tool() -> Response:
|
||||
req = await get_request_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -403,8 +402,8 @@ def cache_tool() -> Response:
|
||||
|
||||
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
|
||||
@validate_request("url", "server_type")
|
||||
def test_mcp() -> Response:
|
||||
req = request.get_json()
|
||||
async def test_mcp() -> Response:
|
||||
req = await get_request_json()
|
||||
|
||||
url = req.get("url", "")
|
||||
if not url:
|
||||
|
||||
@ -15,8 +15,8 @@
|
||||
#
|
||||
|
||||
|
||||
from flask import Response
|
||||
from flask_login import login_required
|
||||
from quart import Response
|
||||
from api.apps import login_required
|
||||
from api.utils.api_utils import get_json_result
|
||||
from plugin import GlobalPluginManager
|
||||
|
||||
|
||||
@ -25,9 +25,9 @@ from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from flask import request, Response
|
||||
from quart import request, Response
|
||||
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@ -41,19 +41,19 @@ def list_agents(tenant_id):
|
||||
return get_error_data_result("The agent doesn't exist.")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
order_by = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
|
||||
return get_result(data=canvas)
|
||||
|
||||
|
||||
@manager.route("/agents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], request.json)
|
||||
async def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -89,8 +89,8 @@ def create_agent(tenant_id: str):
|
||||
|
||||
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None}
|
||||
async def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -135,8 +135,8 @@ def delete_agent(tenant_id: str, agent_id: str):
|
||||
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def webhook(tenant_id: str, agent_id: str):
|
||||
req = request.json
|
||||
async def webhook(tenant_id: str, agent_id: str):
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -159,10 +159,10 @@ def webhook(tenant_id: str, agent_id: str):
|
||||
data=False, message=str(e),
|
||||
code=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
def sse():
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
try:
|
||||
for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
||||
async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
@ -14,22 +14,20 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
from quart import request
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json
|
||||
|
||||
|
||||
@manager.route("/chats", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
req = request.json
|
||||
async def create(tenant_id):
|
||||
req = await get_request_json()
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
@ -145,10 +143,10 @@ def create(tenant_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id):
|
||||
async def update(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
ids = req.get("dataset_ids", [])
|
||||
if "show_quotation" in req:
|
||||
req["do_refer"] = req.pop("show_quotation")
|
||||
@ -228,10 +226,10 @@ def update(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/chats", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
async def delete_chats(tenant_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
ids = None
|
||||
else:
|
||||
@ -251,8 +249,8 @@ def delete(tenant_id):
|
||||
errors.append(f"Assistant({id}) not found.")
|
||||
continue
|
||||
temp_dict = {"status": StatusEnum.INVALID.value}
|
||||
DialogService.update_by_id(id, temp_dict)
|
||||
success_count += 1
|
||||
success_count += DialogService.update_by_id(id, temp_dict)
|
||||
print(success_count, "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$", flush=True)
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
|
||||
@ -18,13 +18,14 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from flask import request
|
||||
from quart import request
|
||||
from peewee import OperationalError
|
||||
from api.db.db_models import File
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, FileSource, StatusEnum
|
||||
from api.utils.api_utils import (
|
||||
@ -53,7 +54,7 @@ from common import settings
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
"""
|
||||
Create a new dataset.
|
||||
---
|
||||
@ -115,17 +116,19 @@ def create(tenant_id):
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
|
||||
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
req, err = await validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
e, req = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = tenant_id,
|
||||
parser_id = req.pop("parser_id", None),
|
||||
**req
|
||||
)
|
||||
|
||||
if not e:
|
||||
return req
|
||||
|
||||
# Insert embedding model(embd id)
|
||||
ok, t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
@ -144,7 +147,6 @@ def create(tenant_id):
|
||||
ok, k = KnowledgebaseService.get_by_id(req["id"])
|
||||
if not ok:
|
||||
return get_error_data_result(message="Dataset created failed")
|
||||
|
||||
response_data = remap_dictionary_keys(k.to_dict())
|
||||
return get_result(data=response_data)
|
||||
except Exception as e:
|
||||
@ -153,7 +155,7 @@ def create(tenant_id):
|
||||
|
||||
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
async def delete(tenant_id):
|
||||
"""
|
||||
Delete datasets.
|
||||
---
|
||||
@ -191,7 +193,7 @@ def delete(tenant_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
|
||||
req, err = await validate_and_parse_json_request(request, DeleteDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
@ -251,7 +253,7 @@ def delete(tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, dataset_id):
|
||||
async def update(tenant_id, dataset_id):
|
||||
"""
|
||||
Update a dataset.
|
||||
---
|
||||
@ -317,7 +319,7 @@ def update(tenant_id, dataset_id):
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
extras = {"dataset_id": dataset_id}
|
||||
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
@ -532,3 +534,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
|
||||
search.index_name(kb.tenant_id), dataset_id)
|
||||
|
||||
return get_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def run_graphrag(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=dataset_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
|
||||
|
||||
return get_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def trace_graphrag(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_result(data={})
|
||||
|
||||
return get_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def run_raptor(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=dataset_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||
logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
|
||||
|
||||
return get_result(data={"raptor_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def trace_raptor(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||
|
||||
return get_result(data=task.to_dict())
|
||||
@ -15,12 +15,12 @@
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request, jsonify
|
||||
from quart import jsonify
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
||||
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from common.constants import RetCode, LLMType
|
||||
@ -29,7 +29,7 @@ from common import settings
|
||||
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
|
||||
@apikey_required
|
||||
@validate_request("knowledge_id", "query")
|
||||
def retrieval(tenant_id):
|
||||
async def retrieval(tenant_id):
|
||||
"""
|
||||
Dify-compatible retrieval API
|
||||
---
|
||||
@ -113,14 +113,14 @@ def retrieval(tenant_id):
|
||||
404:
|
||||
description: Knowledge base or document not found
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
question = req["query"]
|
||||
kb_id = req["knowledge_id"]
|
||||
use_kg = req.get("use_kg", False)
|
||||
retrieval_setting = req.get("retrieval_setting", {})
|
||||
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
|
||||
top = int(retrieval_setting.get("top_k", 1024))
|
||||
metadata_condition = req.get("metadata_condition", {})
|
||||
metadata_condition = req.get("metadata_condition", {}) or {}
|
||||
metas = DocumentService.get_meta_by_kbs([kb_id])
|
||||
|
||||
doc_ids = []
|
||||
@ -131,12 +131,10 @@ def retrieval(tenant_id):
|
||||
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
print(metadata_condition)
|
||||
# print("after", convert_conditions(metadata_condition))
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
|
||||
# print("doc_ids", doc_ids)
|
||||
if not doc_ids and metadata_condition is not None:
|
||||
doc_ids = ['-999']
|
||||
if metadata_condition:
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||
if not doc_ids and metadata_condition:
|
||||
doc_ids = ["-999"]
|
||||
ranks = settings.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
|
||||
@ -20,7 +20,7 @@ import re
|
||||
from io import BytesIO
|
||||
|
||||
import xxhash
|
||||
from flask import request, send_file
|
||||
from quart import request, send_file
|
||||
from peewee import OperationalError
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
@ -33,9 +33,10 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||
get_request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -69,7 +70,7 @@ class Chunk(BaseModel):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def upload(dataset_id, tenant_id):
|
||||
async def upload(dataset_id, tenant_id):
|
||||
"""
|
||||
Upload documents to a dataset.
|
||||
---
|
||||
@ -93,6 +94,10 @@ def upload(dataset_id, tenant_id):
|
||||
type: file
|
||||
required: true
|
||||
description: Document files to upload.
|
||||
- in: formData
|
||||
name: parent_path
|
||||
type: string
|
||||
description: Optional nested path under the parent folder. Uses '/' separators.
|
||||
responses:
|
||||
200:
|
||||
description: Successfully uploaded documents.
|
||||
@ -126,9 +131,11 @@ def upload(dataset_id, tenant_id):
|
||||
type: string
|
||||
description: Processing status.
|
||||
"""
|
||||
if "file" not in request.files:
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
file_objs = request.files.getlist("file")
|
||||
file_objs = files.getlist("file")
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_result(message="No file selected!", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -151,7 +158,7 @@ def upload(dataset_id, tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not e:
|
||||
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=form.get("parent_path"))
|
||||
if err:
|
||||
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
# rename key's name
|
||||
@ -175,7 +182,7 @@ def upload(dataset_id, tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update_doc(tenant_id, dataset_id, document_id):
|
||||
async def update_doc(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Update a document within a dataset.
|
||||
---
|
||||
@ -224,7 +231,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
|
||||
return get_error_data_result(message="You don't own the dataset.")
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
@ -314,9 +321,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
try:
|
||||
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
|
||||
return get_error_data_result(message="Database error (Document update)!")
|
||||
|
||||
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||
return get_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -343,19 +348,17 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
}
|
||||
renamed_doc = {}
|
||||
for key, value in doc.to_dict().items():
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_doc[new_key] = value
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(value)
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
|
||||
return get_result(data=renamed_doc)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def download(tenant_id, dataset_id, document_id):
|
||||
async def download(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Download a document from a dataset.
|
||||
---
|
||||
@ -405,10 +408,10 @@ def download(tenant_id, dataset_id, document_id):
|
||||
return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
|
||||
file = BytesIO(file_stream)
|
||||
# Use send_file with a proper filename and MIME type
|
||||
return send_file(
|
||||
return await send_file(
|
||||
file,
|
||||
as_attachment=True,
|
||||
download_name=doc[0].name,
|
||||
attachment_filename=doc[0].name,
|
||||
mimetype="application/octet-stream", # Set a default MIME type
|
||||
)
|
||||
|
||||
@ -529,7 +532,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
|
||||
q = request.args
|
||||
document_id = q.get("id")
|
||||
document_id = q.get("id")
|
||||
name = q.get("name")
|
||||
|
||||
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
|
||||
@ -538,16 +541,16 @@ def list_docs(dataset_id, tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the document {name}.")
|
||||
|
||||
page = int(q.get("page", 1))
|
||||
page_size = int(q.get("page_size", 30))
|
||||
page_size = int(q.get("page_size", 30))
|
||||
orderby = q.get("orderby", "create_time")
|
||||
desc = str(q.get("desc", "true")).strip().lower() != "false"
|
||||
keywords = q.get("keywords", "")
|
||||
|
||||
# filters - align with OpenAPI parameter names
|
||||
suffix = q.getlist("suffix")
|
||||
run_status = q.getlist("run")
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
create_time_to = int(q.get("create_time_to", 0))
|
||||
suffix = q.getlist("suffix")
|
||||
run_status = q.getlist("run")
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
create_time_to = int(q.get("create_time_to", 0))
|
||||
|
||||
# map run status (accept text or numeric) - align with API parameter
|
||||
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
|
||||
@ -568,7 +571,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
# rename keys + map run status back to text for output
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"kb_id": "dataset_id",
|
||||
"kb_id": "dataset_id",
|
||||
"token_num": "token_count",
|
||||
"parser_id": "chunk_method",
|
||||
}
|
||||
@ -585,7 +588,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, dataset_id):
|
||||
async def delete(tenant_id, dataset_id):
|
||||
"""
|
||||
Delete documents from a dataset.
|
||||
---
|
||||
@ -624,7 +627,7 @@ def delete(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
doc_ids = None
|
||||
else:
|
||||
@ -695,7 +698,7 @@ def delete(tenant_id, dataset_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def parse(tenant_id, dataset_id):
|
||||
async def parse(tenant_id, dataset_id):
|
||||
"""
|
||||
Start parsing documents into chunks.
|
||||
---
|
||||
@ -734,7 +737,7 @@ def parse(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
doc_list = req.get("document_ids")
|
||||
@ -778,7 +781,7 @@ def parse(tenant_id, dataset_id):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def stop_parsing(tenant_id, dataset_id):
|
||||
async def stop_parsing(tenant_id, dataset_id):
|
||||
"""
|
||||
Stop parsing documents into chunks.
|
||||
---
|
||||
@ -817,7 +820,7 @@ def stop_parsing(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
@ -832,6 +835,8 @@ def stop_parsing(tenant_id, dataset_id):
|
||||
return get_error_data_result(message=f"You don't own the document {id}.")
|
||||
if int(doc[0].progress) == 1 or doc[0].progress == 0:
|
||||
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
|
||||
# Send cancellation signal via Redis to stop background task
|
||||
cancel_all_task_of(id)
|
||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||
DocumentService.update_by_id(id, info)
|
||||
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
|
||||
@ -1019,7 +1024,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
"/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["POST"]
|
||||
)
|
||||
@token_required
|
||||
def add_chunk(tenant_id, dataset_id, document_id):
|
||||
async def add_chunk(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Add a chunk to a document.
|
||||
---
|
||||
@ -1089,7 +1094,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not str(req.get("content", "")).strip():
|
||||
return get_error_data_result(message="`content` is required")
|
||||
if "important_keywords" in req:
|
||||
@ -1148,7 +1153,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
"datasets/<dataset_id>/documents/<document_id>/chunks", methods=["DELETE"]
|
||||
)
|
||||
@token_required
|
||||
def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
async def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
Remove chunks from a document.
|
||||
---
|
||||
@ -1195,7 +1200,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
docs = DocumentService.get_by_ids([document_id])
|
||||
if not docs:
|
||||
raise LookupError(f"Can't find the document with ID {document_id}!")
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
condition = {"doc_id": document_id}
|
||||
if "chunk_ids" in req:
|
||||
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
|
||||
@ -1219,7 +1224,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
"/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id>", methods=["PUT"]
|
||||
)
|
||||
@token_required
|
||||
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
"""
|
||||
Update a chunk within a document.
|
||||
---
|
||||
@ -1281,8 +1286,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = request.json
|
||||
if "content" in req:
|
||||
req = await get_request_json()
|
||||
if "content" in req and req["content"] is not None:
|
||||
content = req["content"]
|
||||
else:
|
||||
content = chunk.get("content_with_weight", "")
|
||||
@ -1323,7 +1328,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
|
||||
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def retrieval_test(tenant_id):
|
||||
async def retrieval_test(tenant_id):
|
||||
"""
|
||||
Retrieve chunks based on a query.
|
||||
---
|
||||
@ -1404,7 +1409,7 @@ def retrieval_test(tenant_id):
|
||||
format: float
|
||||
description: Similarity score.
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
if not req.get("dataset_ids"):
|
||||
return get_error_data_result("`dataset_ids` is required.")
|
||||
kb_ids = req["dataset_ids"]
|
||||
@ -1427,6 +1432,7 @@ def retrieval_test(tenant_id):
|
||||
question = req["question"]
|
||||
doc_ids = req.get("document_ids", [])
|
||||
use_kg = req.get("use_kg", False)
|
||||
toc_enhance = req.get("toc_enhance", False)
|
||||
langs = req.get("cross_languages", [])
|
||||
if not isinstance(doc_ids, list):
|
||||
return get_error_data_result("`documents` should be a list")
|
||||
@ -1435,9 +1441,14 @@ def retrieval_test(tenant_id):
|
||||
if doc_id not in doc_ids_list:
|
||||
return get_error_data_result(f"The datasets don't own the document {doc_id}")
|
||||
if not doc_ids:
|
||||
metadata_condition = req.get("metadata_condition", {})
|
||||
metadata_condition = req.get("metadata_condition", {}) or {}
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
doc_ids = meta_filter(metas, convert_conditions(metadata_condition))
|
||||
doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
|
||||
# If metadata_condition has conditions but no docs match, return empty result
|
||||
if not doc_ids and metadata_condition.get("conditions"):
|
||||
return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
|
||||
if metadata_condition and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
@ -1478,6 +1489,11 @@ def retrieval_test(tenant_id):
|
||||
highlight=highlight,
|
||||
rank_feature=label_question(question, kbs),
|
||||
)
|
||||
if toc_enhance:
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
|
||||
if cks:
|
||||
ranks["chunks"] = cks
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
|
||||
@ -14,35 +14,33 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from quart import request, make_response
|
||||
from pathlib import Path
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, token_required
|
||||
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
|
||||
from common.constants import RetCode
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def upload(tenant_id):
|
||||
async def upload(tenant_id):
|
||||
"""
|
||||
Upload a file to the system.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -79,26 +77,28 @@ def upload(tenant_id):
|
||||
type: string
|
||||
description: File type (e.g., document, folder)
|
||||
"""
|
||||
pf_id = request.form.get("parent_id")
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
pf_id = form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(data=False, message='No file part!', code=400)
|
||||
file_objs = request.files.getlist('file')
|
||||
if 'file' not in files:
|
||||
return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST)
|
||||
file_objs = files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(data=False, message='No selected file!', code=400)
|
||||
return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST)
|
||||
|
||||
file_res = []
|
||||
|
||||
try:
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND)
|
||||
|
||||
for file_obj in file_objs:
|
||||
# Handle file path
|
||||
@ -114,13 +114,13 @@ def upload(tenant_id):
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
@ -151,12 +151,12 @@ def upload(tenant_id):
|
||||
|
||||
@manager.route('/file/create', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
"""
|
||||
Create a new file or folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -193,16 +193,16 @@ def create(tenant_id):
|
||||
type:
|
||||
type: string
|
||||
"""
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
req = await get_request_json()
|
||||
pf_id = req.get("parent_id")
|
||||
input_file_type = req.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
try:
|
||||
if not FileService.is_parent_folder_exist(pf_id):
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400)
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
|
||||
if FileService.query(name=req["name"], parent_id=pf_id):
|
||||
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
|
||||
|
||||
@ -229,12 +229,12 @@ def create(tenant_id):
|
||||
|
||||
@manager.route('/file/list', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_files(tenant_id):
|
||||
async def list_files(tenant_id):
|
||||
"""
|
||||
List files under a specific folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -306,13 +306,13 @@ def list_files(tenant_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(pf_id)
|
||||
if not parent_folder:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
|
||||
except Exception as e:
|
||||
@ -321,12 +321,12 @@ def list_files(tenant_id):
|
||||
|
||||
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_root_folder(tenant_id):
|
||||
async def get_root_folder(tenant_id):
|
||||
"""
|
||||
Get user's root folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
@ -357,12 +357,12 @@ def get_root_folder(tenant_id):
|
||||
|
||||
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_parent_folder():
|
||||
async def get_parent_folder():
|
||||
"""
|
||||
Get parent folder info of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -392,7 +392,7 @@ def get_parent_folder():
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(file_id)
|
||||
return get_json_result(data={"parent_folder": parent_folder.to_json()})
|
||||
@ -402,12 +402,12 @@ def get_parent_folder():
|
||||
|
||||
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_all_parent_folders(tenant_id):
|
||||
async def get_all_parent_folders(tenant_id):
|
||||
"""
|
||||
Get all parent folders of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -439,7 +439,7 @@ def get_all_parent_folders(tenant_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
parent_folders = FileService.get_all_parent_folders(file_id)
|
||||
parent_folders_res = [folder.to_json() for folder in parent_folders]
|
||||
@ -450,12 +450,12 @@ def get_all_parent_folders(tenant_id):
|
||||
|
||||
@manager.route('/file/rm', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rm(tenant_id):
|
||||
async def rm(tenant_id):
|
||||
"""
|
||||
Delete one or multiple files/folders.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -481,40 +481,40 @@ def rm(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for inner_file_id in file_id_list:
|
||||
e, file = FileService.get_by_id(inner_file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
FileService.delete_folder_by_pf_id(tenant_id, file_id)
|
||||
else:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
if not FileService.delete(file):
|
||||
return get_json_result(message="Database error (File removal)!", code=500)
|
||||
return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file_id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(message="Database error (Document removal)!", code=500)
|
||||
return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR)
|
||||
File2DocumentService.delete_by_file_id(file_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
@ -524,12 +524,12 @@ def rm(tenant_id):
|
||||
|
||||
@manager.route('/file/rename', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rename(tenant_id):
|
||||
async def rename(tenant_id):
|
||||
"""
|
||||
Rename a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -556,27 +556,27 @@ def rename(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=400)
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
|
||||
|
||||
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if existing_file.name == req["name"]:
|
||||
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
|
||||
|
||||
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (File rename)!", code=500)
|
||||
return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(req["file_id"])
|
||||
if informs:
|
||||
if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (Document rename)!", code=500)
|
||||
return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -585,12 +585,12 @@ def rename(tenant_id):
|
||||
|
||||
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get(tenant_id, file_id):
|
||||
async def get(tenant_id, file_id):
|
||||
"""
|
||||
Download a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
produces:
|
||||
@ -606,20 +606,20 @@ def get(tenant_id, file_id):
|
||||
description: File stream
|
||||
schema:
|
||||
type: file
|
||||
404:
|
||||
RetCode.NOT_FOUND:
|
||||
description: File not found
|
||||
"""
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = settings.STORAGE_IMPL.get(b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name)
|
||||
if ext:
|
||||
if file.type == FileType.VISUAL.value:
|
||||
@ -630,15 +630,28 @@ def get(tenant_id, file_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
async def download_attachment(tenant_id,attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def move(tenant_id):
|
||||
async def move(tenant_id):
|
||||
"""
|
||||
Move one or multiple files to another folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -667,7 +680,7 @@ def move(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
parent_id = req["dest_file_id"]
|
||||
@ -677,13 +690,13 @@ def move(tenant_id):
|
||||
for file_id in file_ids:
|
||||
file = files_dict[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_json_result(message="Parent Folder not found!", code=404)
|
||||
return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
FileService.move_file(file_ids, parent_id)
|
||||
return get_json_result(data=True)
|
||||
@ -693,8 +706,8 @@ def move(tenant_id):
|
||||
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def convert(tenant_id):
|
||||
req = request.json
|
||||
async def convert(tenant_id):
|
||||
req = await get_request_json()
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
@ -705,7 +718,7 @@ def convert(tenant_id):
|
||||
for file_id in file_ids:
|
||||
file = files_set[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
file_ids_list = [file_id]
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
@ -716,13 +729,13 @@ def convert(tenant_id):
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(
|
||||
message="Database error (Document removal)!", code=404)
|
||||
message="Database error (Document removal)!", code=RetCode.NOT_FOUND)
|
||||
File2DocumentService.delete_by_file_id(id)
|
||||
|
||||
# insert
|
||||
@ -730,11 +743,11 @@ def convert(tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this knowledgebase!", code=404)
|
||||
message="Can't find this knowledgebase!", code=RetCode.NOT_FOUND)
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this file!", code=404)
|
||||
message="Can't find this file!", code=RetCode.NOT_FOUND)
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
|
||||
@ -13,12 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
||||
import tiktoken
|
||||
from flask import Response, jsonify, request
|
||||
from quart import Response, jsonify, request
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db.db_models import APIToken
|
||||
@ -35,7 +36,7 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||
get_result, server_error_response, token_required, validate_request
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
@ -44,8 +45,8 @@ from common import settings
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id, chat_id):
|
||||
req = request.json
|
||||
async def create(tenant_id, chat_id):
|
||||
req = await get_request_json()
|
||||
req["dialog_id"] = chat_id
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
@ -73,7 +74,7 @@ def create(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent_session(tenant_id, agent_id):
|
||||
async def create_agent_session(tenant_id, agent_id):
|
||||
user_id = request.args.get("user_id", tenant_id)
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not e:
|
||||
@ -97,8 +98,8 @@ def create_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id, session_id):
|
||||
req = request.json
|
||||
async def update(tenant_id, chat_id, session_id):
|
||||
req = await get_request_json()
|
||||
req["dialog_id"] = chat_id
|
||||
conv_id = session_id
|
||||
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
||||
@ -119,8 +120,8 @@ def update(tenant_id, chat_id, session_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def chat_completion(tenant_id, chat_id):
|
||||
req = request.json
|
||||
async def chat_completion(tenant_id, chat_id):
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
@ -149,7 +150,7 @@ def chat_completion(tenant_id, chat_id):
|
||||
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def chat_completion_openai_like(tenant_id, chat_id):
|
||||
async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"""
|
||||
OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
|
||||
|
||||
@ -206,7 +207,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
if reference:
|
||||
print(completion.choices[0].message.reference)
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
need_reference = bool(req.get("reference", False))
|
||||
|
||||
@ -383,8 +384,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = request.json
|
||||
async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
messages = req.get("messages", [])
|
||||
if not messages:
|
||||
@ -428,28 +429,26 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
return resp
|
||||
else:
|
||||
# For non-streaming, just return the response directly
|
||||
response = next(
|
||||
completion_openai(
|
||||
async for response in completion_openai(
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
|
||||
stream=False,
|
||||
**req,
|
||||
)
|
||||
)
|
||||
return jsonify(response)
|
||||
):
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def agent_completions(tenant_id, agent_id):
|
||||
req = request.json
|
||||
async def agent_completions(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
|
||||
if req.get("stream", True):
|
||||
|
||||
def generate():
|
||||
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
async def generate():
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
ans = json.loads(answer[5:]) # remove "data:"
|
||||
@ -473,7 +472,7 @@ def agent_completions(tenant_id, agent_id):
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = ""
|
||||
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
try:
|
||||
ans = json.loads(answer[5:])
|
||||
|
||||
@ -493,7 +492,7 @@ def agent_completions(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_session(tenant_id, chat_id):
|
||||
async def list_session(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
||||
id = request.args.get("id")
|
||||
@ -547,7 +546,7 @@ def list_session(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_agent_session(tenant_id, agent_id):
|
||||
async def list_agent_session(tenant_id, agent_id):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
||||
id = request.args.get("id")
|
||||
@ -610,13 +609,13 @@ def list_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, chat_id):
|
||||
async def delete(tenant_id, chat_id):
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You don't own the chat")
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
convs = ConversationService.query(dialog_id=chat_id)
|
||||
if not req:
|
||||
ids = None
|
||||
@ -661,10 +660,10 @@ def delete(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent_session(tenant_id, agent_id):
|
||||
async def delete_agent_session(tenant_id, agent_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
@ -716,8 +715,8 @@ def delete_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def ask_about(tenant_id):
|
||||
req = request.json
|
||||
async def ask_about(tenant_id):
|
||||
req = await get_request_json()
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
if not req.get("dataset_ids"):
|
||||
@ -755,8 +754,8 @@ def ask_about(tenant_id):
|
||||
|
||||
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def related_questions(tenant_id):
|
||||
req = request.json
|
||||
async def related_questions(tenant_id):
|
||||
req = await get_request_json()
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
question = req["question"]
|
||||
@ -789,7 +788,7 @@ Reason:
|
||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||
|
||||
"""
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
@ -806,8 +805,8 @@ Related search terms:
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def chatbot_completions(dialog_id):
|
||||
req = request.json
|
||||
async def chatbot_completions(dialog_id):
|
||||
req = await get_request_json()
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -833,7 +832,7 @@ def chatbot_completions(dialog_id):
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
|
||||
def chatbots_inputs(dialog_id):
|
||||
async def chatbots_inputs(dialog_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -856,8 +855,8 @@ def chatbots_inputs(dialog_id):
|
||||
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def agent_bot_completions(agent_id):
|
||||
req = request.json
|
||||
async def agent_bot_completions(agent_id):
|
||||
req = await get_request_json()
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -875,12 +874,12 @@ def agent_bot_completions(agent_id):
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
|
||||
async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
|
||||
return get_result(data=answer)
|
||||
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
||||
def begin_inputs(agent_id):
|
||||
async def begin_inputs(agent_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -901,7 +900,7 @@ def begin_inputs(agent_id):
|
||||
|
||||
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about_embedded():
|
||||
async def ask_about_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -910,7 +909,7 @@ def ask_about_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
uid = objs[0].tenant_id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -940,7 +939,7 @@ def ask_about_embedded():
|
||||
|
||||
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test_embedded():
|
||||
async def retrieval_test_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -949,7 +948,7 @@ def retrieval_test_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -965,28 +964,30 @@ def retrieval_test_embedded():
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
tenant_ids = []
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
def _retrieval_sync():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
_question = question
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, _question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
@ -1002,7 +1003,7 @@ def retrieval_test_embedded():
|
||||
return get_error_data_result(message="Knowledgebase not found!")
|
||||
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -1012,15 +1013,15 @@ def retrieval_test_embedded():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
_question += keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(
|
||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
|
||||
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
@ -1030,6 +1031,9 @@ def retrieval_test_embedded():
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_retrieval_sync)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
||||
@ -1039,7 +1043,7 @@ def retrieval_test_embedded():
|
||||
|
||||
@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question")
|
||||
def related_questions_embedded():
|
||||
async def related_questions_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1048,7 +1052,7 @@ def related_questions_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
@ -1066,7 +1070,7 @@ def related_questions_embedded():
|
||||
|
||||
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
@ -1083,7 +1087,7 @@ Related search terms:
|
||||
|
||||
|
||||
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
|
||||
def detail_share_embedded():
|
||||
async def detail_share_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1115,7 +1119,7 @@ def detail_share_embedded():
|
||||
|
||||
@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
async def mindmap():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1125,7 +1129,7 @@ def mindmap():
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.db.db_models import DB
|
||||
@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/create", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.get_json()
|
||||
async def create():
|
||||
req = await get_request_json()
|
||||
search_name = req["name"]
|
||||
description = req.get("description", "")
|
||||
if not isinstance(search_name, str):
|
||||
@ -65,8 +65,8 @@ def create():
|
||||
@login_required
|
||||
@validate_request("search_id", "name", "search_config", "tenant_id")
|
||||
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.get_json()
|
||||
async def update():
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Search name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -140,7 +140,7 @@ def detail():
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_search_app():
|
||||
async def list_search_app():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
@ -150,7 +150,7 @@ def list_search_app():
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -173,8 +173,8 @@ def list_search_app():
|
||||
@manager.route("/rm", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("search_id")
|
||||
def rm():
|
||||
req = request.get_json()
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
search_id = req["search_id"]
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -17,7 +17,7 @@ import logging
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService
|
||||
@ -34,7 +34,7 @@ from common.time_utils import current_timestamp, datetime_format
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from quart import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -14,10 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.apps import smtp_mail_server
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import UserTenant
|
||||
from api.db.services.user_service import UserTenantService, UserService
|
||||
@ -25,9 +21,10 @@ from api.db.services.user_service import UserTenantService, UserService
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import delta_seconds
|
||||
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import send_invite_email
|
||||
from common import settings
|
||||
from api.apps import smtp_mail_server, login_required, current_user
|
||||
|
||||
|
||||
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
|
||||
@ -51,14 +48,14 @@ def user_list(tenant_id):
|
||||
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("email")
|
||||
def create(tenant_id):
|
||||
async def create(tenant_id):
|
||||
if current_user.id != tenant_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
invite_user_email = req["email"]
|
||||
invite_users = UserService.query(email=invite_user_email)
|
||||
if not invite_users:
|
||||
|
||||
@ -22,8 +22,7 @@ import secrets
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from flask import redirect, request, session, make_response
|
||||
from flask_login import current_user, login_required, login_user, logout_user
|
||||
from quart import redirect, request, session, make_response
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
from api.apps.auth import get_auth_client
|
||||
@ -40,12 +39,13 @@ from common.connection_utils import construct_response
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
get_request_json,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import smtp_mail_server
|
||||
from api.apps import smtp_mail_server, login_required, current_user, login_user, logout_user
|
||||
from api.utils.web_utils import (
|
||||
send_email_html,
|
||||
OTP_LENGTH,
|
||||
@ -58,10 +58,11 @@ from api.utils.web_utils import (
|
||||
captcha_key,
|
||||
)
|
||||
from common import settings
|
||||
from common.http_client import async_request
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
def login():
|
||||
async def login():
|
||||
"""
|
||||
User login endpoint.
|
||||
---
|
||||
@ -91,10 +92,11 @@ def login():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
if not request.json:
|
||||
json_body = await get_request_json()
|
||||
if not json_body:
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
email = request.json.get("email", "")
|
||||
email = json_body.get("email", "")
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(
|
||||
@ -103,7 +105,7 @@ def login():
|
||||
message=f"Email: {email} is not registered!",
|
||||
)
|
||||
|
||||
password = request.json.get("password")
|
||||
password = json_body.get("password")
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
@ -121,11 +123,12 @@ def login():
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.update_time = current_timestamp()
|
||||
user.update_date = datetime_format(datetime.now())
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
|
||||
return await construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -135,7 +138,7 @@ def login():
|
||||
|
||||
|
||||
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
||||
def get_login_channels():
|
||||
async def get_login_channels():
|
||||
"""
|
||||
Get all supported authentication channels.
|
||||
"""
|
||||
@ -156,7 +159,7 @@ def get_login_channels():
|
||||
|
||||
|
||||
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_login(channel):
|
||||
async def oauth_login(channel):
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
raise ValueError(f"Invalid channel name: {channel}")
|
||||
@ -169,7 +172,7 @@ def oauth_login(channel):
|
||||
|
||||
|
||||
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_callback(channel):
|
||||
async def oauth_callback(channel):
|
||||
"""
|
||||
Handle the OAuth/OIDC callback for various channels dynamically.
|
||||
"""
|
||||
@ -191,7 +194,10 @@ def oauth_callback(channel):
|
||||
return redirect("/?error=missing_code")
|
||||
|
||||
# Exchange authorization code for access token
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
if hasattr(auth_cli, "async_exchange_code_for_token"):
|
||||
token_info = await auth_cli.async_exchange_code_for_token(code)
|
||||
else:
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
access_token = token_info.get("access_token")
|
||||
if not access_token:
|
||||
return redirect("/?error=token_failed")
|
||||
@ -199,7 +205,10 @@ def oauth_callback(channel):
|
||||
id_token = token_info.get("id_token")
|
||||
|
||||
# Fetch user info
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if hasattr(auth_cli, "async_fetch_user_info"):
|
||||
user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
|
||||
else:
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if not user_info.email:
|
||||
return redirect("/?error=email_missing")
|
||||
|
||||
@ -258,7 +267,7 @@ def oauth_callback(channel):
|
||||
|
||||
|
||||
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
|
||||
def github_callback():
|
||||
async def github_callback():
|
||||
"""
|
||||
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
||||
|
||||
@ -278,9 +287,8 @@ def github_callback():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
res = requests.post(
|
||||
res = await async_request(
|
||||
"POST",
|
||||
settings.GITHUB_OAUTH.get("url"),
|
||||
data={
|
||||
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||
@ -298,7 +306,7 @@ def github_callback():
|
||||
|
||||
session["access_token"] = res["access_token"]
|
||||
session["access_token_from"] = "github"
|
||||
user_info = user_info_from_github(session["access_token"])
|
||||
user_info = await user_info_from_github(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
@ -347,7 +355,7 @@ def github_callback():
|
||||
|
||||
|
||||
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
|
||||
def feishu_callback():
|
||||
async def feishu_callback():
|
||||
"""
|
||||
Feishu OAuth callback endpoint.
|
||||
---
|
||||
@ -365,9 +373,8 @@ def feishu_callback():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
app_access_token_res = requests.post(
|
||||
app_access_token_res = await async_request(
|
||||
"POST",
|
||||
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
@ -381,7 +388,8 @@ def feishu_callback():
|
||||
if app_access_token_res["code"] != 0:
|
||||
return redirect("/?error=%s" % app_access_token_res)
|
||||
|
||||
res = requests.post(
|
||||
res = await async_request(
|
||||
"POST",
|
||||
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
@ -402,7 +410,7 @@ def feishu_callback():
|
||||
return redirect("/?error=contact:user.email:readonly not in scope")
|
||||
session["access_token"] = res["data"]["access_token"]
|
||||
session["access_token_from"] = "feishu"
|
||||
user_info = user_info_from_feishu(session["access_token"])
|
||||
user_info = await user_info_from_feishu(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
@ -450,36 +458,34 @@ def feishu_callback():
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
|
||||
|
||||
def user_info_from_feishu(access_token):
|
||||
import requests
|
||||
|
||||
async def user_info_from_feishu(access_token):
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
user_info = res.json()["data"]
|
||||
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
import requests
|
||||
|
||||
async def user_info_from_github(access_token):
|
||||
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(
|
||||
email_info_response = await async_request(
|
||||
"GET",
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers,
|
||||
).json()
|
||||
)
|
||||
email_info = email_info_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@manager.route("/logout", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def log_out():
|
||||
async def log_out():
|
||||
"""
|
||||
User logout endpoint.
|
||||
---
|
||||
@ -501,7 +507,7 @@ def log_out():
|
||||
|
||||
@manager.route("/setting", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def setting_user():
|
||||
async def setting_user():
|
||||
"""
|
||||
Update user settings.
|
||||
---
|
||||
@ -530,7 +536,7 @@ def setting_user():
|
||||
type: object
|
||||
"""
|
||||
update_dict = {}
|
||||
request_data = request.json
|
||||
request_data = await get_request_json()
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
@ -569,7 +575,7 @@ def setting_user():
|
||||
|
||||
@manager.route("/info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def user_profile():
|
||||
async def user_profile():
|
||||
"""
|
||||
Get user profile information.
|
||||
---
|
||||
@ -660,7 +666,7 @@ def user_register(user_id, user):
|
||||
|
||||
@manager.route("/register", methods=["POST"]) # noqa: F821
|
||||
@validate_request("nickname", "email", "password")
|
||||
def user_add():
|
||||
async def user_add():
|
||||
"""
|
||||
Register a new user.
|
||||
---
|
||||
@ -697,7 +703,7 @@ def user_add():
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
email_address = req["email"]
|
||||
|
||||
# Validate the email address
|
||||
@ -737,7 +743,7 @@ def user_add():
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return construct_response(
|
||||
return await construct_response(
|
||||
data=user.to_json(),
|
||||
auth=user.get_id(),
|
||||
message=f"{nickname}, welcome aboard!",
|
||||
@ -754,7 +760,7 @@ def user_add():
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def tenant_info():
|
||||
async def tenant_info():
|
||||
"""
|
||||
Get tenant information.
|
||||
---
|
||||
@ -793,7 +799,7 @@ def tenant_info():
|
||||
@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
|
||||
def set_tenant_info():
|
||||
async def set_tenant_info():
|
||||
"""
|
||||
Update tenant information.
|
||||
---
|
||||
@ -830,17 +836,17 @@ def set_tenant_info():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
tid = req.pop("tenant_id")
|
||||
TenantService.update_by_id(tid, req)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
|
||||
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
|
||||
def forget_get_captcha():
|
||||
async def forget_get_captcha():
|
||||
"""
|
||||
GET /forget/captcha?email=<email>
|
||||
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS.
|
||||
@ -862,19 +868,19 @@ def forget_get_captcha():
|
||||
from captcha.image import ImageCaptcha
|
||||
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
|
||||
img_bytes = image.generate(captcha_text).read()
|
||||
response = make_response(img_bytes)
|
||||
response = await make_response(img_bytes)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/forget/otp", methods=["POST"]) # noqa: F821
|
||||
def forget_send_otp():
|
||||
async def forget_send_otp():
|
||||
"""
|
||||
POST /forget/otp
|
||||
- Verify the image captcha stored at captcha:{email} (case-insensitive).
|
||||
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
email = req.get("email") or ""
|
||||
captcha = (req.get("captcha") or "").strip()
|
||||
|
||||
@ -930,17 +936,17 @@ def forget_send_otp():
|
||||
)
|
||||
except Exception:
|
||||
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")
|
||||
|
||||
|
||||
return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")
|
||||
|
||||
|
||||
@manager.route("/forget", methods=["POST"]) # noqa: F821
|
||||
def forget():
|
||||
async def forget():
|
||||
"""
|
||||
POST: Verify email + OTP and reset password, then log the user in.
|
||||
Request JSON: { email, otp, new_password, confirm_new_password }
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await get_request_json()
|
||||
email = req.get("email") or ""
|
||||
otp = (req.get("otp") or "").strip()
|
||||
new_pwd = req.get("new_password")
|
||||
@ -1001,8 +1007,8 @@ def forget():
|
||||
# Auto login (reuse login flow)
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.update_time = current_timestamp()
|
||||
user.update_date = datetime_format(datetime.now())
|
||||
user.save()
|
||||
msg = "Password reset successful. Logged in."
|
||||
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
|
||||
@ -25,7 +25,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import UserMixin
|
||||
from quart_auth import AuthUser
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
@ -305,6 +305,7 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
|
||||
@ -594,7 +595,7 @@ def fill_db_model_object(model_object, human_model_dict):
|
||||
return model_object
|
||||
|
||||
|
||||
class User(DataBaseModel, UserMixin):
|
||||
class User(DataBaseModel, AuthUser):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
access_token = CharField(max_length=255, null=True, index=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
|
||||
@ -748,7 +749,7 @@ class Knowledgebase(DataBaseModel):
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
@ -772,8 +773,8 @@ class Document(DataBaseModel):
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
|
||||
@ -876,7 +877,7 @@ class Dialog(DataBaseModel):
|
||||
class Conversation(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="conversation name", index=True)
|
||||
message = JSONField(null=True)
|
||||
reference = JSONField(null=True, default=[])
|
||||
user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)
|
||||
@ -1112,6 +1113,70 @@ class SyncLogs(DataBaseModel):
|
||||
db_table = "sync_logs"
|
||||
|
||||
|
||||
class EvaluationDataset(DataBaseModel):
|
||||
"""Ground truth dataset for RAG evaluation"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
|
||||
name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
|
||||
description = TextField(null=True, help_text="dataset description")
|
||||
kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
|
||||
created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
|
||||
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||
update_time = BigIntegerField(null=False, help_text="last update timestamp")
|
||||
status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_datasets"
|
||||
|
||||
|
||||
class EvaluationCase(DataBaseModel):
|
||||
"""Individual test case in an evaluation dataset"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||
question = TextField(null=False, help_text="test question")
|
||||
reference_answer = TextField(null=True, help_text="optional ground truth answer")
|
||||
relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
|
||||
relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
|
||||
metadata = JSONField(null=True, help_text="additional context/tags")
|
||||
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_cases"
|
||||
|
||||
|
||||
class EvaluationRun(DataBaseModel):
|
||||
"""A single evaluation run"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||
dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
|
||||
name = CharField(max_length=255, null=False, help_text="run name")
|
||||
config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
|
||||
metrics_summary = JSONField(null=True, help_text="aggregated metrics")
|
||||
status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
|
||||
created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
|
||||
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||
complete_time = BigIntegerField(null=True, help_text="completion timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_runs"
|
||||
|
||||
|
||||
class EvaluationResult(DataBaseModel):
|
||||
"""Result for a single test case in an evaluation run"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
|
||||
case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
|
||||
generated_answer = TextField(null=False, help_text="generated answer")
|
||||
retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
|
||||
metrics = JSONField(null=False, help_text="all computed metrics")
|
||||
execution_time = FloatField(null=False, help_text="response time in seconds")
|
||||
token_usage = JSONField(null=True, help_text="prompt/completion tokens")
|
||||
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_results"
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1292,4 +1357,43 @@ def migrate_db():
|
||||
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# RAG Evaluation tables
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -34,14 +34,17 @@ from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
DEFAULT_SUPERUSER_NICKNAME = os.getenv("DEFAULT_SUPERUSER_NICKNAME", "admin")
|
||||
DEFAULT_SUPERUSER_EMAIL = os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")
|
||||
DEFAULT_SUPERUSER_PASSWORD = os.getenv("DEFAULT_SUPERUSER_PASSWORD", "admin")
|
||||
|
||||
def init_superuser():
|
||||
def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_EMAIL, password=DEFAULT_SUPERUSER_PASSWORD, role=UserTenantRole.OWNER):
|
||||
user_info = {
|
||||
"id": uuid.uuid1().hex,
|
||||
"password": encode_to_base64("admin"),
|
||||
"nickname": "admin",
|
||||
"password": encode_to_base64(password),
|
||||
"nickname": nickname,
|
||||
"is_superuser": True,
|
||||
"email": "admin@ragflow.io",
|
||||
"email": email,
|
||||
"creator": "system",
|
||||
"status": "1",
|
||||
}
|
||||
@ -58,7 +61,7 @@ def init_superuser():
|
||||
"tenant_id": user_info["id"],
|
||||
"user_id": user_info["id"],
|
||||
"invited_by": user_info["id"],
|
||||
"role": UserTenantRole.OWNER
|
||||
"role": role
|
||||
}
|
||||
|
||||
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||
@ -70,7 +73,7 @@ def init_superuser():
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
logging.info(
|
||||
"Super user initialized. email: admin@ragflow.io, password: admin. Changing the password after login is strongly recommended.")
|
||||
f"Super user initialized. email: {email}, password: {password}. Changing the password after login is strongly recommended.")
|
||||
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = chat_mdl.chat(system="", history=[
|
||||
|
||||
@ -177,7 +177,7 @@ class UserCanvasService(CommonService):
|
||||
return True
|
||||
|
||||
|
||||
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
query = kwargs.get("query", "") or kwargs.get("question", "")
|
||||
files = kwargs.get("files", [])
|
||||
inputs = kwargs.get("inputs", {})
|
||||
@ -219,10 +219,14 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
"id": message_id
|
||||
})
|
||||
txt = ""
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
ans["session_id"] = session_id
|
||||
if ans["event"] == "message":
|
||||
txt += ans["data"]["content"]
|
||||
if ans["data"].get("start_to_think", False):
|
||||
txt += "<think>"
|
||||
elif ans["data"].get("end_to_think", False):
|
||||
txt += "</think>"
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
|
||||
@ -233,7 +237,7 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
API4ConversationService.append_message(conv["id"], conv)
|
||||
|
||||
|
||||
def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
|
||||
prompt_tokens = len(tiktoken_encoder.encode(str(question)))
|
||||
user_id = kwargs.get("user_id", "")
|
||||
@ -241,7 +245,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru
|
||||
if stream:
|
||||
completion_tokens = 0
|
||||
try:
|
||||
for ans in completion(
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
@ -300,7 +304,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru
|
||||
try:
|
||||
all_content = ""
|
||||
reference = {}
|
||||
for ans in completion(
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import os
|
||||
from typing import Tuple, List
|
||||
|
||||
from anthropic import BaseModel
|
||||
@ -24,7 +25,6 @@ from api.db import InputType
|
||||
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
@ -68,9 +68,10 @@ class ConnectorService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
|
||||
from api.db.services.file_service import FileService
|
||||
e, conn = cls.get_by_id(connector_id)
|
||||
if not e:
|
||||
return
|
||||
return None
|
||||
SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id)
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
@ -103,7 +104,8 @@ class SyncLogsService(CommonService):
|
||||
Knowledgebase.avatar.alias("kb_avatar"),
|
||||
Connector2Kb.auto_parse,
|
||||
cls.model.from_beginning.alias("reindex"),
|
||||
cls.model.status
|
||||
cls.model.status,
|
||||
cls.model.update_time
|
||||
]
|
||||
if not connector_id:
|
||||
fields.append(Connector.config)
|
||||
@ -116,7 +118,11 @@ class SyncLogsService(CommonService):
|
||||
if connector_id:
|
||||
query = query.where(cls.model.connector_id == connector_id)
|
||||
else:
|
||||
interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
|
||||
database_type = os.getenv("DB_TYPE", "mysql")
|
||||
if "postgres" in database_type.lower():
|
||||
interval_expr = SQL("make_interval(mins => t2.refresh_freq)")
|
||||
else:
|
||||
interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
|
||||
query = query.where(
|
||||
Connector.input_type == InputType.POLL,
|
||||
Connector.status == TaskStatus.SCHEDULE,
|
||||
@ -125,11 +131,11 @@ class SyncLogsService(CommonService):
|
||||
)
|
||||
|
||||
query = query.distinct().order_by(cls.model.update_time.desc())
|
||||
totbal = query.count()
|
||||
total = query.count()
|
||||
if page_number:
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), totbal
|
||||
return list(query.dicts()), total
|
||||
|
||||
@classmethod
|
||||
def start(cls, id, connector_id):
|
||||
@ -191,6 +197,7 @@ class SyncLogsService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
|
||||
from api.db.services.file_service import FileService
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
@ -207,9 +214,21 @@ class SyncLogsService(CommonService):
|
||||
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
|
||||
errs.extend(err)
|
||||
|
||||
# Create a mapping from filename to metadata for later use
|
||||
metadata_map = {}
|
||||
for d in docs:
|
||||
if d.get("metadata"):
|
||||
filename = d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else "")
|
||||
metadata_map[filename] = d["metadata"]
|
||||
|
||||
kb_table_num_map = {}
|
||||
for doc, _ in doc_blob_pairs:
|
||||
doc_ids.append(doc["id"])
|
||||
|
||||
# Set metadata if available for this document
|
||||
if doc["name"] in metadata_map:
|
||||
DocumentService.update_by_id(doc["id"], {"meta_fields": metadata_map[doc["name"]]})
|
||||
|
||||
if not auto_parse or auto_parse == "0":
|
||||
continue
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
@ -242,7 +261,7 @@ class Connector2KbService(CommonService):
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id,
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ import trio
|
||||
from langfuse import Langfuse
|
||||
from peewee import fn
|
||||
from agentic_reasoning import DeepResearcher
|
||||
from api.db.services.file_service import FileService
|
||||
from common.constants import LLMType, ParserType, StatusEnum
|
||||
from api.db.db_models import DB, Dialog
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -178,6 +179,9 @@ class DialogService(CommonService):
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
attachments = ""
|
||||
if "files" in messages[-1]:
|
||||
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
@ -188,6 +192,8 @@ def chat_solo(dialog, messages, stream=True):
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if stream:
|
||||
last_ans = ""
|
||||
delta_ans = ""
|
||||
@ -287,7 +293,7 @@ def convert_conditions(metadata_condition):
|
||||
]
|
||||
|
||||
|
||||
def meta_filter(metas: dict, filters: list[dict]):
|
||||
def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
||||
doc_ids = set([])
|
||||
|
||||
def filter_out(v2docs, operator, value):
|
||||
@ -304,6 +310,8 @@ def meta_filter(metas: dict, filters: list[dict]):
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "in", str(input).lower() in str(value).lower()),
|
||||
(operator == "not in", str(input).lower() not in str(value).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
@ -331,7 +339,10 @@ def meta_filter(metas: dict, filters: list[dict]):
|
||||
if not doc_ids:
|
||||
doc_ids = set(ids)
|
||||
else:
|
||||
doc_ids = doc_ids & set(ids)
|
||||
if logic == "and":
|
||||
doc_ids = doc_ids & set(ids)
|
||||
else:
|
||||
doc_ids = doc_ids | set(ids)
|
||||
if not doc_ids:
|
||||
return []
|
||||
return list(doc_ids)
|
||||
@ -342,7 +353,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||
for ans in chat_solo(dialog, messages, stream):
|
||||
yield ans
|
||||
return
|
||||
return None
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
@ -375,8 +386,11 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
retriever = settings.retriever
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
attachments_= ""
|
||||
if "doc_ids" in messages[-1]:
|
||||
attachments = messages[-1]["doc_ids"]
|
||||
if "files" in messages[-1]:
|
||||
attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
@ -386,7 +400,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
return None
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":
|
||||
@ -407,14 +421,15 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if dialog.meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||
if dialog.meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
elif dialog.meta_data_filter.get("method") == "manual":
|
||||
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
conds = dialog.meta_data_filter["manual"]
|
||||
attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and")))
|
||||
if conds and not attachments:
|
||||
attachments = ["-999"]
|
||||
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
@ -445,7 +460,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
),
|
||||
)
|
||||
|
||||
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||
for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
@ -472,6 +487,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
|
||||
if prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
@ -497,7 +513,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
gen_conf = dialog.llm_setting
|
||||
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
|
||||
prompt4citation = ""
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
prompt4citation = citation_prompt()
|
||||
@ -617,6 +633,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
res["audio_binary"] = tts(tts_mdl, answer)
|
||||
yield res
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = """
|
||||
@ -664,7 +682,11 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
if kb_ids:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
if "where" not in sql.lower():
|
||||
sql += f" WHERE {kb_filter}"
|
||||
o = sql.lower().split("order by")
|
||||
if len(o) > 1:
|
||||
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
|
||||
else:
|
||||
sql += f" WHERE {kb_filter}"
|
||||
else:
|
||||
sql += f" AND {kb_filter}"
|
||||
|
||||
@ -672,10 +694,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
tried_times += 1
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
tbl, sql = get_table()
|
||||
if tbl is None:
|
||||
return None
|
||||
if tbl.get("error") and tried_times <= 2:
|
||||
try:
|
||||
tbl, sql = get_table()
|
||||
except Exception as e:
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
@ -689,16 +710,14 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Error issued by database as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
|
||||
tbl, sql = get_table()
|
||||
logging.debug("TRY it again: {}".format(sql))
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
|
||||
try:
|
||||
tbl, sql = get_table()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
logging.debug("GET table: {}".format(tbl))
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||
if len(tbl["rows"]) == 0:
|
||||
return None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
@ -742,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
"prompt": sys_prompt,
|
||||
}
|
||||
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||
|
||||
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||
|
||||
emoji_pattern = re.compile(
|
||||
"[\U0001F600-\U0001F64F"
|
||||
"\U0001F300-\U0001F5FF"
|
||||
"\U0001F680-\U0001F6FF"
|
||||
"\U0001F1E0-\U0001F1FF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U0001F900-\U0001F9FF"
|
||||
"\U0001FA70-\U0001FAFF"
|
||||
"\U0001FAD0-\U0001FAFF]+",
|
||||
flags=re.UNICODE
|
||||
)
|
||||
text = emoji_pattern.sub("", text)
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
MAX_LEN = 500
|
||||
if len(text) > MAX_LEN:
|
||||
text = text[:MAX_LEN]
|
||||
|
||||
return text
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text:
|
||||
return
|
||||
return None
|
||||
text = clean_tts_text(text)
|
||||
if not text:
|
||||
return None
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
except Exception as e:
|
||||
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
|
||||
@ -776,14 +830,14 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
question=question,
|
||||
@ -851,14 +905,14 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
question=question,
|
||||
|
||||
@ -41,6 +41,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common import settings
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
|
||||
@ -113,7 +114,7 @@ class DocumentService(CommonService):
|
||||
def check_doc_health(cls, tenant_id: str, filename):
|
||||
import os
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
|
||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
|
||||
raise RuntimeError("Exceed the maximum file number of a free user!")
|
||||
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
raise RuntimeError("Exceed the maximum length of file name!")
|
||||
@ -309,7 +310,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -322,7 +323,7 @@ class DocumentService(CommonService):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
graph_source = settings.docStoreConn.getFields(
|
||||
graph_source = settings.docStoreConn.get_fields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
@ -464,7 +465,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -473,7 +474,7 @@ class DocumentService(CommonService):
|
||||
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["kb_id"]
|
||||
|
||||
@classmethod
|
||||
@ -486,7 +487,7 @@ class DocumentService(CommonService):
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -533,7 +534,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["embd_id"]
|
||||
|
||||
@classmethod
|
||||
@ -569,7 +570,7 @@ class DocumentService(CommonService):
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return
|
||||
return None
|
||||
return doc_id[0]["id"]
|
||||
|
||||
@classmethod
|
||||
@ -715,13 +716,17 @@ class DocumentService(CommonService):
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
# only for special task and parsed docs and unfinised
|
||||
# only for special task and parsed docs and unfinished
|
||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||
msg = "\n".join(sorted(msg))
|
||||
begin_at = d.get("process_begin_at")
|
||||
if not begin_at:
|
||||
begin_at = datetime.now()
|
||||
# fallback
|
||||
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
|
||||
|
||||
info = {
|
||||
"process_duration": datetime.timestamp(
|
||||
datetime.now()) -
|
||||
d["process_begin_at"].timestamp(),
|
||||
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
|
||||
"run": status}
|
||||
if prg != 0 and not freeze_progress:
|
||||
info["progress"] = prg
|
||||
@ -922,7 +927,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
doc_nm = {}
|
||||
@ -974,13 +979,13 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
def embedding(doc_id, cnts, batch_size=16):
|
||||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vects = []
|
||||
vectors = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
vectors.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vects
|
||||
return vectors
|
||||
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
try_create_idx = True
|
||||
@ -1011,15 +1016,15 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
except Exception:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vects)
|
||||
vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vectors)
|
||||
for i, d in enumerate(cks):
|
||||
v = vects[i]
|
||||
v = vectors[i]
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
598
api/db/services/evaluation_service.py
Normal file
598
api/db/services/evaluation_service.py
Normal file
@ -0,0 +1,598 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
RAG Evaluation Service
|
||||
|
||||
Provides functionality for evaluating RAG system performance including:
|
||||
- Dataset management
|
||||
- Test case management
|
||||
- Evaluation execution
|
||||
- Metrics computation
|
||||
- Configuration recommendations
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp
|
||||
from common.constants import StatusEnum
|
||||
|
||||
|
||||
class EvaluationService(CommonService):
|
||||
"""Service for managing RAG evaluations"""
|
||||
|
||||
model = EvaluationDataset
|
||||
|
||||
# ==================== Dataset Management ====================
|
||||
|
||||
@classmethod
|
||||
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
|
||||
tenant_id: str, user_id: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Create a new evaluation dataset.
|
||||
|
||||
Args:
|
||||
name: Dataset name
|
||||
description: Dataset description
|
||||
kb_ids: List of knowledge base IDs to evaluate against
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID who creates the dataset
|
||||
|
||||
Returns:
|
||||
(success, dataset_id or error_message)
|
||||
"""
|
||||
try:
|
||||
dataset_id = get_uuid()
|
||||
dataset = {
|
||||
"id": dataset_id,
|
||||
"tenant_id": tenant_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"kb_ids": kb_ids,
|
||||
"created_by": user_id,
|
||||
"create_time": current_timestamp(),
|
||||
"update_time": current_timestamp(),
|
||||
"status": StatusEnum.VALID.value
|
||||
}
|
||||
|
||||
if not EvaluationDataset.create(**dataset):
|
||||
return False, "Failed to create dataset"
|
||||
|
||||
return True, dataset_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error creating evaluation dataset: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get dataset by ID"""
|
||||
try:
|
||||
dataset = EvaluationDataset.get_by_id(dataset_id)
|
||||
if dataset:
|
||||
return dataset.to_dict()
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting dataset {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def list_datasets(cls, tenant_id: str, user_id: str,
|
||||
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
|
||||
"""List datasets for a tenant"""
|
||||
try:
|
||||
query = EvaluationDataset.select().where(
|
||||
(EvaluationDataset.tenant_id == tenant_id) &
|
||||
(EvaluationDataset.status == StatusEnum.VALID.value)
|
||||
).order_by(EvaluationDataset.create_time.desc())
|
||||
|
||||
total = query.count()
|
||||
datasets = query.paginate(page, page_size)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"datasets": [d.to_dict() for d in datasets]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error listing datasets: {e}")
|
||||
return {"total": 0, "datasets": []}
|
||||
|
||||
@classmethod
|
||||
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
|
||||
"""Update dataset"""
|
||||
try:
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
return EvaluationDataset.update(**kwargs).where(
|
||||
EvaluationDataset.id == dataset_id
|
||||
).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating dataset {dataset_id}: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def delete_dataset(cls, dataset_id: str) -> bool:
|
||||
"""Soft delete dataset"""
|
||||
try:
|
||||
return EvaluationDataset.update(
|
||||
status=StatusEnum.INVALID.value,
|
||||
update_time=current_timestamp()
|
||||
).where(EvaluationDataset.id == dataset_id).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting dataset {dataset_id}: {e}")
|
||||
return False
|
||||
|
||||
# ==================== Test Case Management ====================
|
||||
|
||||
@classmethod
|
||||
def add_test_case(cls, dataset_id: str, question: str,
|
||||
reference_answer: Optional[str] = None,
|
||||
relevant_doc_ids: Optional[List[str]] = None,
|
||||
relevant_chunk_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Add a test case to a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
question: Test question
|
||||
reference_answer: Optional ground truth answer
|
||||
relevant_doc_ids: Optional list of relevant document IDs
|
||||
relevant_chunk_ids: Optional list of relevant chunk IDs
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
(success, case_id or error_message)
|
||||
"""
|
||||
try:
|
||||
case_id = get_uuid()
|
||||
case = {
|
||||
"id": case_id,
|
||||
"dataset_id": dataset_id,
|
||||
"question": question,
|
||||
"reference_answer": reference_answer,
|
||||
"relevant_doc_ids": relevant_doc_ids,
|
||||
"relevant_chunk_ids": relevant_chunk_ids,
|
||||
"metadata": metadata,
|
||||
"create_time": current_timestamp()
|
||||
}
|
||||
|
||||
if not EvaluationCase.create(**case):
|
||||
return False, "Failed to create test case"
|
||||
|
||||
return True, case_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error adding test case: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all test cases for a dataset"""
|
||||
try:
|
||||
cases = EvaluationCase.select().where(
|
||||
EvaluationCase.dataset_id == dataset_id
|
||||
).order_by(EvaluationCase.create_time)
|
||||
|
||||
return [c.to_dict() for c in cases]
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def delete_test_case(cls, case_id: str) -> bool:
|
||||
"""Delete a test case"""
|
||||
try:
|
||||
return EvaluationCase.delete().where(
|
||||
EvaluationCase.id == case_id
|
||||
).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting test case {case_id}: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
|
||||
"""
|
||||
Bulk import test cases from a list.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
cases: List of test case dictionaries
|
||||
|
||||
Returns:
|
||||
(success_count, failure_count)
|
||||
"""
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for case_data in cases:
|
||||
success, _ = cls.add_test_case(
|
||||
dataset_id=dataset_id,
|
||||
question=case_data.get("question", ""),
|
||||
reference_answer=case_data.get("reference_answer"),
|
||||
relevant_doc_ids=case_data.get("relevant_doc_ids"),
|
||||
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
|
||||
metadata=case_data.get("metadata")
|
||||
)
|
||||
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
return success_count, failure_count
|
||||
|
||||
# ==================== Evaluation Execution ====================
|
||||
|
||||
@classmethod
|
||||
def start_evaluation(cls, dataset_id: str, dialog_id: str,
|
||||
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
dialog_id: Dialog configuration to evaluate
|
||||
user_id: User ID who starts the run
|
||||
name: Optional run name
|
||||
|
||||
Returns:
|
||||
(success, run_id or error_message)
|
||||
"""
|
||||
try:
|
||||
# Get dialog configuration
|
||||
success, dialog = DialogService.get_by_id(dialog_id)
|
||||
if not success:
|
||||
return False, "Dialog not found"
|
||||
|
||||
# Create evaluation run
|
||||
run_id = get_uuid()
|
||||
if not name:
|
||||
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
run = {
|
||||
"id": run_id,
|
||||
"dataset_id": dataset_id,
|
||||
"dialog_id": dialog_id,
|
||||
"name": name,
|
||||
"config_snapshot": dialog.to_dict(),
|
||||
"metrics_summary": None,
|
||||
"status": "RUNNING",
|
||||
"created_by": user_id,
|
||||
"create_time": current_timestamp(),
|
||||
"complete_time": None
|
||||
}
|
||||
|
||||
if not EvaluationRun.create(**run):
|
||||
return False, "Failed to create evaluation run"
|
||||
|
||||
# Execute evaluation asynchronously (in production, use task queue)
|
||||
# For now, we'll execute synchronously
|
||||
cls._execute_evaluation(run_id, dataset_id, dialog)
|
||||
|
||||
return True, run_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error starting evaluation: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
|
||||
"""
|
||||
Execute evaluation for all test cases.
|
||||
|
||||
This method runs the RAG pipeline for each test case and computes metrics.
|
||||
"""
|
||||
try:
|
||||
# Get all test cases
|
||||
test_cases = cls.get_test_cases(dataset_id)
|
||||
|
||||
if not test_cases:
|
||||
EvaluationRun.update(
|
||||
status="FAILED",
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
return
|
||||
|
||||
# Execute each test case
|
||||
results = []
|
||||
for case in test_cases:
|
||||
result = cls._evaluate_single_case(run_id, case, dialog)
|
||||
if result:
|
||||
results.append(result)
|
||||
|
||||
# Compute summary metrics
|
||||
metrics_summary = cls._compute_summary_metrics(results)
|
||||
|
||||
# Update run status
|
||||
EvaluationRun.update(
|
||||
status="COMPLETED",
|
||||
metrics_summary=metrics_summary,
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing evaluation {run_id}: {e}")
|
||||
EvaluationRun.update(
|
||||
status="FAILED",
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
|
||||
@classmethod
|
||||
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
|
||||
dialog: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Evaluate a single test case.
|
||||
|
||||
Args:
|
||||
run_id: Evaluation run ID
|
||||
case: Test case dictionary
|
||||
dialog: Dialog configuration
|
||||
|
||||
Returns:
|
||||
Result dictionary or None if failed
|
||||
"""
|
||||
try:
|
||||
# Prepare messages
|
||||
messages = [{"role": "user", "content": case["question"]}]
|
||||
|
||||
# Execute RAG pipeline
|
||||
start_time = timer()
|
||||
answer = ""
|
||||
retrieved_chunks = []
|
||||
|
||||
for ans in chat(dialog, messages, stream=False):
|
||||
if isinstance(ans, dict):
|
||||
answer = ans.get("answer", "")
|
||||
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
|
||||
break
|
||||
|
||||
execution_time = timer() - start_time
|
||||
|
||||
# Compute metrics
|
||||
metrics = cls._compute_metrics(
|
||||
question=case["question"],
|
||||
generated_answer=answer,
|
||||
reference_answer=case.get("reference_answer"),
|
||||
retrieved_chunks=retrieved_chunks,
|
||||
relevant_chunk_ids=case.get("relevant_chunk_ids"),
|
||||
dialog=dialog
|
||||
)
|
||||
|
||||
# Save result
|
||||
result_id = get_uuid()
|
||||
result = {
|
||||
"id": result_id,
|
||||
"run_id": run_id,
|
||||
"case_id": case["id"],
|
||||
"generated_answer": answer,
|
||||
"retrieved_chunks": retrieved_chunks,
|
||||
"metrics": metrics,
|
||||
"execution_time": execution_time,
|
||||
"token_usage": None, # TODO: Track token usage
|
||||
"create_time": current_timestamp()
|
||||
}
|
||||
|
||||
EvaluationResult.create(**result)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error evaluating case {case.get('id')}: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _compute_metrics(cls, question: str, generated_answer: str,
|
||||
reference_answer: Optional[str],
|
||||
retrieved_chunks: List[Dict[str, Any]],
|
||||
relevant_chunk_ids: Optional[List[str]],
|
||||
dialog: Any) -> Dict[str, float]:
|
||||
"""
|
||||
Compute evaluation metrics for a single test case.
|
||||
|
||||
Returns:
|
||||
Dictionary of metric names to values
|
||||
"""
|
||||
metrics = {}
|
||||
|
||||
# Retrieval metrics (if ground truth chunks provided)
|
||||
if relevant_chunk_ids:
|
||||
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
|
||||
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
|
||||
|
||||
# Generation metrics
|
||||
if generated_answer:
|
||||
# Basic metrics
|
||||
metrics["answer_length"] = len(generated_answer)
|
||||
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
|
||||
|
||||
# TODO: Implement advanced metrics using LLM-as-judge
|
||||
# - Faithfulness (hallucination detection)
|
||||
# - Answer relevance
|
||||
# - Context relevance
|
||||
# - Semantic similarity (if reference answer provided)
|
||||
|
||||
return metrics
|
||||
|
||||
@classmethod
|
||||
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
|
||||
relevant_ids: List[str]) -> Dict[str, float]:
|
||||
"""
|
||||
Compute retrieval metrics.
|
||||
|
||||
Args:
|
||||
retrieved_ids: List of retrieved chunk IDs
|
||||
relevant_ids: List of relevant chunk IDs (ground truth)
|
||||
|
||||
Returns:
|
||||
Dictionary of retrieval metrics
|
||||
"""
|
||||
if not relevant_ids:
|
||||
return {}
|
||||
|
||||
retrieved_set = set(retrieved_ids)
|
||||
relevant_set = set(relevant_ids)
|
||||
|
||||
# Precision: proportion of retrieved that are relevant
|
||||
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
|
||||
|
||||
# Recall: proportion of relevant that were retrieved
|
||||
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
|
||||
|
||||
# F1 score
|
||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||
|
||||
# Hit rate: whether any relevant chunk was retrieved
|
||||
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
|
||||
|
||||
# MRR (Mean Reciprocal Rank): position of first relevant chunk
|
||||
mrr = 0.0
|
||||
for i, chunk_id in enumerate(retrieved_ids, 1):
|
||||
if chunk_id in relevant_set:
|
||||
mrr = 1.0 / i
|
||||
break
|
||||
|
||||
return {
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1_score": f1,
|
||||
"hit_rate": hit_rate,
|
||||
"mrr": mrr
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Compute summary metrics across all test cases.
|
||||
|
||||
Args:
|
||||
results: List of result dictionaries
|
||||
|
||||
Returns:
|
||||
Summary metrics dictionary
|
||||
"""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
# Aggregate metrics
|
||||
metric_sums = {}
|
||||
metric_counts = {}
|
||||
|
||||
for result in results:
|
||||
metrics = result.get("metrics", {})
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
metric_sums[key] = metric_sums.get(key, 0) + value
|
||||
metric_counts[key] = metric_counts.get(key, 0) + 1
|
||||
|
||||
# Compute averages
|
||||
summary = {
|
||||
"total_cases": len(results),
|
||||
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
|
||||
}
|
||||
|
||||
for key in metric_sums:
|
||||
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
|
||||
|
||||
return summary
|
||||
|
||||
# ==================== Results & Analysis ====================
|
||||
|
||||
@classmethod
|
||||
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
|
||||
"""Get results for an evaluation run"""
|
||||
try:
|
||||
run = EvaluationRun.get_by_id(run_id)
|
||||
if not run:
|
||||
return {}
|
||||
|
||||
results = EvaluationResult.select().where(
|
||||
EvaluationResult.run_id == run_id
|
||||
).order_by(EvaluationResult.create_time)
|
||||
|
||||
return {
|
||||
"run": run.to_dict(),
|
||||
"results": [r.to_dict() for r in results]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting run results {run_id}: {e}")
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Analyze evaluation results and provide configuration recommendations.
|
||||
|
||||
Args:
|
||||
run_id: Evaluation run ID
|
||||
|
||||
Returns:
|
||||
List of recommendation dictionaries
|
||||
"""
|
||||
try:
|
||||
run = EvaluationRun.get_by_id(run_id)
|
||||
if not run or not run.metrics_summary:
|
||||
return []
|
||||
|
||||
metrics = run.metrics_summary
|
||||
recommendations = []
|
||||
|
||||
# Low precision: retrieving irrelevant chunks
|
||||
if metrics.get("avg_precision", 1.0) < 0.7:
|
||||
recommendations.append({
|
||||
"issue": "Low Precision",
|
||||
"severity": "high",
|
||||
"description": "System is retrieving many irrelevant chunks",
|
||||
"suggestions": [
|
||||
"Increase similarity_threshold to filter out less relevant chunks",
|
||||
"Enable reranking to improve chunk ordering",
|
||||
"Reduce top_k to return fewer chunks"
|
||||
]
|
||||
})
|
||||
|
||||
# Low recall: missing relevant chunks
|
||||
if metrics.get("avg_recall", 1.0) < 0.7:
|
||||
recommendations.append({
|
||||
"issue": "Low Recall",
|
||||
"severity": "high",
|
||||
"description": "System is missing relevant chunks",
|
||||
"suggestions": [
|
||||
"Increase top_k to retrieve more chunks",
|
||||
"Lower similarity_threshold to be more inclusive",
|
||||
"Enable hybrid search (keyword + semantic)",
|
||||
"Check chunk size - may be too large or too small"
|
||||
]
|
||||
})
|
||||
|
||||
# Slow response time
|
||||
if metrics.get("avg_execution_time", 0) > 5.0:
|
||||
recommendations.append({
|
||||
"issue": "Slow Response Time",
|
||||
"severity": "medium",
|
||||
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
|
||||
"suggestions": [
|
||||
"Reduce top_k to retrieve fewer chunks",
|
||||
"Optimize embedding model selection",
|
||||
"Consider caching frequently asked questions"
|
||||
]
|
||||
})
|
||||
|
||||
return recommendations
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating recommendations for run {run_id}: {e}")
|
||||
return []
|
||||
@ -13,12 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType
|
||||
@ -31,7 +35,7 @@ from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus, FileSource, ParserType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
|
||||
from rag.llm.cv_model import GptV4
|
||||
from common import settings
|
||||
|
||||
@ -184,6 +188,7 @@ class FileService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_folder(cls, file, parent_id, name, count):
|
||||
from api.apps import current_user
|
||||
# Recursively create folder structure
|
||||
# Args:
|
||||
# file: Current file object
|
||||
@ -329,7 +334,7 @@ class FileService(CommonService):
|
||||
current_id = start_id
|
||||
while current_id:
|
||||
e, file = cls.get_by_id(current_id)
|
||||
if file.parent_id != file.id and e:
|
||||
if e and file.parent_id != file.id:
|
||||
parent_folders.append(file)
|
||||
current_id = file.parent_id
|
||||
else:
|
||||
@ -423,13 +428,15 @@ class FileService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def upload_document(self, kb, file_objs, user_id, src="local"):
|
||||
def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None):
|
||||
root_folder = self.get_root_folder(user_id)
|
||||
pf_id = root_folder["id"]
|
||||
self.init_knowledgebase_docs(pf_id, user_id)
|
||||
kb_root_folder = self.get_kb_folder(user_id)
|
||||
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
safe_parent_path = sanitize_path(parent_path)
|
||||
|
||||
err, files = [], []
|
||||
for file in file_objs:
|
||||
try:
|
||||
@ -439,7 +446,7 @@ class FileService(CommonService):
|
||||
if filetype == FileType.OTHER.value:
|
||||
raise RuntimeError("This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}"
|
||||
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
|
||||
location += "_"
|
||||
|
||||
@ -506,6 +513,7 @@ class FileService(CommonService):
|
||||
@staticmethod
|
||||
def parse(filename, blob, img_base64=True, tenant_id=None):
|
||||
from rag.app import audio, email, naive, picture, presentation
|
||||
from api.apps import current_user
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
@ -517,7 +525,7 @@ class FileService(CommonService):
|
||||
if img_base64 and file_type == FileType.VISUAL.value:
|
||||
return GptV4.image2base64(blob)
|
||||
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
|
||||
return "\n".join([ck["content_with_weight"] for ck in cks])
|
||||
return f"\n -----------------\nFile: {filename}\nContent as following: \n" + "\n".join([ck["content_with_weight"] for ck in cks])
|
||||
|
||||
@staticmethod
|
||||
def get_parser(doc_type, filename, default):
|
||||
@ -585,3 +593,80 @@ class FileService(CommonService):
|
||||
errors += str(e)
|
||||
|
||||
return errors
|
||||
|
||||
@staticmethod
|
||||
def upload_info(user_id, file, url: str|None=None):
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None
|
||||
}
|
||||
|
||||
if url:
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
DefaultMarkdownGenerator,
|
||||
PruningContentFilter,
|
||||
CrawlResult
|
||||
)
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter()
|
||||
),
|
||||
pdf=True,
|
||||
screenshot=False
|
||||
)
|
||||
result: CrawlResult = await crawler.arun(
|
||||
url=url,
|
||||
config=crawler_config
|
||||
)
|
||||
return result
|
||||
page = asyncio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return structured(filename, "pdf", page.pdf, page.response_headers["content-type"])
|
||||
|
||||
return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)
|
||||
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
|
||||
|
||||
@staticmethod
|
||||
def get_files(files: Union[None, list[dict]]) -> list[str]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return [th.result() for th in threads]
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ from common.constants import StatusEnum
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.utils.api_utils import get_parser_config, get_data_error_result
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
|
||||
"""Service class for managing knowledge base operations.
|
||||
|
||||
@ -391,12 +392,12 @@ class KnowledgebaseService(CommonService):
|
||||
"""
|
||||
# Validate name
|
||||
if not isinstance(name, str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
return False, get_data_error_result(message="Dataset name must be string.")
|
||||
dataset_name = name.strip()
|
||||
if dataset_name == "":
|
||||
return get_data_error_result(message="Dataset name can't be empty.")
|
||||
return False, get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
|
||||
# Deduplicate name within tenant
|
||||
dataset_name = duplicate_name(
|
||||
@ -409,7 +410,7 @@ class KnowledgebaseService(CommonService):
|
||||
# Verify tenant exists
|
||||
ok, _t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
return False, "Tenant not found."
|
||||
return False, get_data_error_result(message="Tenant not found.")
|
||||
|
||||
# Build payload
|
||||
kb_id = get_uuid()
|
||||
@ -419,12 +420,13 @@ class KnowledgebaseService(CommonService):
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"parser_id": (parser_id or "naive"),
|
||||
**kwargs
|
||||
**kwargs # Includes optional fields such as description, language, permission, avatar, parser_config, etc.
|
||||
}
|
||||
|
||||
# Default parser_config (align with kb_app.create) — do not accept external overrides
|
||||
# Update parser_config (always override with validated default/merged config)
|
||||
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
|
||||
return payload
|
||||
|
||||
return True, payload
|
||||
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -13,12 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
from common.constants import LLMType
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||
@ -32,6 +35,14 @@ def get_init_tenant_llm(user_id):
|
||||
from common import settings
|
||||
tenant_llm = []
|
||||
|
||||
model_configs = {
|
||||
LLMType.CHAT: settings.CHAT_CFG,
|
||||
LLMType.EMBEDDING: settings.EMBEDDING_CFG,
|
||||
LLMType.SPEECH2TEXT: settings.ASR_CFG,
|
||||
LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
|
||||
LLMType.RERANK: settings.RERANK_CFG,
|
||||
}
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
@ -54,8 +65,8 @@ def get_init_tenant_llm(user_id):
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": factory_config["api_key"],
|
||||
"api_base": factory_config["base_url"],
|
||||
"api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
|
||||
"api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
@ -80,8 +91,8 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def encode(self, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
|
||||
safe_texts = []
|
||||
for text in texts:
|
||||
token_size = num_tokens_from_string(text)
|
||||
@ -90,7 +101,7 @@ class LLMBundle(LLM4Tenant):
|
||||
safe_texts.append(text[:target_len])
|
||||
else:
|
||||
safe_texts.append(text)
|
||||
|
||||
|
||||
embeddings, used_tokens = self.mdl.encode(safe_texts)
|
||||
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
@ -174,6 +185,66 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
return txt
|
||||
|
||||
def stream_transcription(self, audio):
|
||||
mdl = self.mdl
|
||||
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
|
||||
if supports_stream:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name}
|
||||
)
|
||||
final_text = ""
|
||||
used_tokens = 0
|
||||
|
||||
try:
|
||||
for evt in mdl.stream_transcription(audio):
|
||||
if evt.get("event") == "final":
|
||||
final_text = evt.get("text", "")
|
||||
|
||||
yield evt
|
||||
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield err
|
||||
final_text = final_text or ""
|
||||
finally:
|
||||
if final_text:
|
||||
used_tokens = num_tokens_from_string(final_text)
|
||||
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": final_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
)
|
||||
generation.end()
|
||||
|
||||
return
|
||||
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
|
||||
full_text, used_tokens = mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens
|
||||
):
|
||||
logging.error(
|
||||
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
|
||||
)
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": full_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
)
|
||||
generation.end()
|
||||
|
||||
yield {
|
||||
"event": "final",
|
||||
"text": full_text,
|
||||
"streaming": False
|
||||
}
|
||||
|
||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
|
||||
@ -233,7 +304,7 @@ class LLMBundle(LLM4Tenant):
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -270,5 +341,89 @@ class LLMBundle(LLM4Tenant):
|
||||
yield ans
|
||||
|
||||
if total_tokens > 0:
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
|
||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
|
||||
def _bridge_sync_stream(self, gen):
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for item in gen:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||
except Exception as e: # pragma: no cover
|
||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start()
|
||||
return queue
|
||||
|
||||
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
||||
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
|
||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
||||
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
|
||||
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
|
||||
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
|
||||
elif hasattr(self.mdl, "async_chat"):
|
||||
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
|
||||
else:
|
||||
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
|
||||
|
||||
txt = self._remove_reasoning_content(txt)
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
|
||||
return txt
|
||||
|
||||
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
total_tokens = 0
|
||||
ans = ""
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
||||
else:
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans[: -len("</think>")]
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
ans += txt
|
||||
yield ans
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
return
|
||||
|
||||
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
if isinstance(item, int):
|
||||
total_tokens = item
|
||||
break
|
||||
yield item
|
||||
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
|
||||
@ -20,28 +20,26 @@
|
||||
|
||||
from common.log_utils import init_root_logger
|
||||
from plugin import GlobalPluginManager
|
||||
init_root_logger("ragflow_server")
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import threading
|
||||
import uuid
|
||||
import faulthandler
|
||||
|
||||
from werkzeug.serving import run_simple
|
||||
from api.apps import app, smtp_mail_server
|
||||
from api.db.runtime_config import RuntimeConfig
|
||||
from api.db.services.document_service import DocumentService
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.init_data import init_web_data
|
||||
from api.db.init_data import init_web_data, init_superuser
|
||||
from common.versions import get_ragflow_version
|
||||
from common.config_utils import show_configs
|
||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
stop_event = threading.Event()
|
||||
@ -70,10 +68,12 @@ def signal_handler(sig, frame):
|
||||
logging.info("Received interrupt signal, shutting down...")
|
||||
shutdown_all_mcp_sessions()
|
||||
stop_event.set()
|
||||
time.sleep(1)
|
||||
stop_event.wait(1)
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
faulthandler.enable()
|
||||
init_root_logger("ragflow_server")
|
||||
logging.info(r"""
|
||||
____ ___ ______ ______ __
|
||||
/ __ \ / | / ____// ____// /____ _ __
|
||||
@ -110,11 +110,16 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--debug", default=False, help="debug mode", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-superuser", default=False, help="init superuser", action="store_true"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.version:
|
||||
print(get_ragflow_version())
|
||||
sys.exit(0)
|
||||
|
||||
if args.init_superuser:
|
||||
init_superuser()
|
||||
RuntimeConfig.DEBUG = args.debug
|
||||
if RuntimeConfig.DEBUG:
|
||||
logging.info("run on debug mode")
|
||||
@ -153,16 +158,9 @@ if __name__ == '__main__':
|
||||
# start http server
|
||||
try:
|
||||
logging.info("RAGFlow HTTP server start...")
|
||||
run_simple(
|
||||
hostname=settings.HOST_IP,
|
||||
port=settings.HOST_PORT,
|
||||
application=app,
|
||||
threaded=True,
|
||||
use_reloader=RuntimeConfig.DEBUG,
|
||||
use_debugger=RuntimeConfig.DEBUG,
|
||||
)
|
||||
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
stop_event.set()
|
||||
time.sleep(1)
|
||||
stop_event.wait(1)
|
||||
os.kill(os.getpid(), signal.SIGKILL)
|
||||
|
||||
@ -15,29 +15,29 @@
|
||||
#
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import trio
|
||||
from flask import (
|
||||
from quart import (
|
||||
Response,
|
||||
jsonify,
|
||||
request
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
|
||||
from peewee import OperationalError
|
||||
|
||||
from common.constants import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
from common.connection_utils import timeout
|
||||
from common.constants import RetCode
|
||||
@ -46,6 +46,41 @@ from common import settings
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
|
||||
async def _coerce_request_data() -> dict:
|
||||
"""Fetch JSON body with sane defaults; fallback to form data."""
|
||||
payload: Any = None
|
||||
last_error: Exception | None = None
|
||||
|
||||
try:
|
||||
payload = await request.get_json(force=True, silent=True)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
payload = None
|
||||
|
||||
if payload is None:
|
||||
try:
|
||||
form = await request.form
|
||||
payload = form.to_dict()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
payload = None
|
||||
|
||||
if payload is None:
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise ValueError("No JSON body or form data found in request.")
|
||||
|
||||
if isinstance(payload, dict):
|
||||
return payload or {}
|
||||
|
||||
if isinstance(payload, str):
|
||||
raise AttributeError("'str' object has no attribute 'get'")
|
||||
|
||||
raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
|
||||
|
||||
async def get_request_json():
|
||||
return await _coerce_request_data()
|
||||
|
||||
def serialize_for_json(obj):
|
||||
"""
|
||||
Recursively serialize objects to make them JSON serializable.
|
||||
@ -84,7 +119,8 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
|
||||
|
||||
|
||||
def server_error_response(e):
|
||||
logging.exception(e)
|
||||
# Quart invokes this handler outside the original except block, so we must pass exc_info manually.
|
||||
logging.error("Unhandled exception during request", exc_info=(type(e), e, e.__traceback__))
|
||||
try:
|
||||
msg = repr(e).lower()
|
||||
if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
|
||||
@ -105,31 +141,37 @@ def server_error_response(e):
|
||||
|
||||
|
||||
def validate_request(*args, **kwargs):
|
||||
def process_args(input_arguments):
|
||||
no_arguments = []
|
||||
error_arguments = []
|
||||
for arg in args:
|
||||
if arg not in input_arguments:
|
||||
no_arguments.append(arg)
|
||||
for k, v in kwargs.items():
|
||||
config_value = input_arguments.get(k, None)
|
||||
if config_value is None:
|
||||
no_arguments.append(k)
|
||||
elif isinstance(v, (tuple, list)):
|
||||
if config_value not in v:
|
||||
error_arguments.append((k, set(v)))
|
||||
elif config_value != v:
|
||||
error_arguments.append((k, v))
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return error_string
|
||||
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*_args, **_kwargs):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
no_arguments = []
|
||||
error_arguments = []
|
||||
for arg in args:
|
||||
if arg not in input_arguments:
|
||||
no_arguments.append(arg)
|
||||
for k, v in kwargs.items():
|
||||
config_value = input_arguments.get(k, None)
|
||||
if config_value is None:
|
||||
no_arguments.append(k)
|
||||
elif isinstance(v, (tuple, list)):
|
||||
if config_value not in v:
|
||||
error_arguments.append((k, set(v)))
|
||||
elif config_value != v:
|
||||
error_arguments.append((k, v))
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string)
|
||||
async def decorated_function(*_args, **_kwargs):
|
||||
errs = process_args(await _coerce_request_data())
|
||||
if errs:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*_args, **_kwargs)
|
||||
return func(*_args, **_kwargs)
|
||||
|
||||
return decorated_function
|
||||
@ -138,30 +180,34 @@ def validate_request(*args, **kwargs):
|
||||
|
||||
|
||||
def not_allowed_parameters(*params):
|
||||
def decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
def decorator(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
input_arguments = await _coerce_request_data()
|
||||
for param in params:
|
||||
if param in input_arguments:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
def active_required(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
from api.db.services import UserService
|
||||
from api.apps import current_user
|
||||
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return get_json_result(code=RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||
return f(*args, **kwargs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -173,12 +219,15 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non
|
||||
|
||||
def apikey_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
token = flask_request.headers.get("Authorization").split()[1]
|
||||
async def decorated_function(*args, **kwargs):
|
||||
token = request.headers.get("Authorization").split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
@ -199,23 +248,38 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
|
||||
|
||||
|
||||
def token_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
def get_tenant_id(**kwargs):
|
||||
if os.environ.get("DISABLE_SDK"):
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_str = flask_request.headers.get("Authorization")
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_str = request.headers.get("Authorization")
|
||||
if not authorization_str:
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_list = authorization_str.split()
|
||||
if len(authorization_list) < 2:
|
||||
return get_json_result(data=False, message="Please check your authorization format.")
|
||||
return False, get_json_result(data=False, message="Please check your authorization format.")
|
||||
token = authorization_list[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
return True, kwargs
|
||||
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def adecorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return adecorated_function
|
||||
return decorated_function
|
||||
|
||||
|
||||
@ -279,6 +343,10 @@ def get_parser_config(chunk_method, parser_config):
|
||||
chunk_method = "naive"
|
||||
|
||||
# Define default configurations for each chunking method
|
||||
base_defaults = {
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
}
|
||||
key_mapping = {
|
||||
"naive": {
|
||||
"layout_recognize": "DeepDOC",
|
||||
@ -331,16 +399,19 @@ def get_parser_config(chunk_method, parser_config):
|
||||
|
||||
default_config = key_mapping[chunk_method]
|
||||
|
||||
# If no parser_config provided, return default
|
||||
# If no parser_config provided, return default merged with base defaults
|
||||
if not parser_config:
|
||||
return default_config
|
||||
if default_config is None:
|
||||
return deep_merge(base_defaults, {})
|
||||
return deep_merge(base_defaults, default_config)
|
||||
|
||||
# If parser_config is provided, merge with defaults to ensure required fields exist
|
||||
if default_config is None:
|
||||
return parser_config
|
||||
return deep_merge(base_defaults, parser_config)
|
||||
|
||||
# Ensure raptor and graphrag fields have default values if not provided
|
||||
merged_config = deep_merge(default_config, parser_config)
|
||||
merged_config = deep_merge(base_defaults, default_config)
|
||||
merged_config = deep_merge(merged_config, parser_config)
|
||||
|
||||
return merged_config
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ import base64
|
||||
import click
|
||||
import re
|
||||
|
||||
from flask import Flask
|
||||
from quart import Quart
|
||||
from werkzeug.security import generate_password_hash
|
||||
|
||||
from api.db.services import UserService
|
||||
@ -73,6 +73,7 @@ def reset_email(email, new_email, email_confirm):
|
||||
UserService.update_user(user[0].id,user_dict)
|
||||
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
|
||||
|
||||
def register_commands(app: Flask):
|
||||
|
||||
def register_commands(app: Quart):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Reusable HTML email templates and registry.
|
||||
"""
|
||||
|
||||
@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob):
|
||||
return repaired
|
||||
|
||||
return blob
|
||||
|
||||
|
||||
def sanitize_path(raw_path: str | None) -> str:
|
||||
"""Normalize and sanitize a user-provided path segment.
|
||||
|
||||
- Converts backslashes to forward slashes
|
||||
- Strips leading/trailing slashes
|
||||
- Removes '.' and '..' segments
|
||||
- Restricts characters to A-Za-z0-9, underscore, dash, and '/'
|
||||
"""
|
||||
if not raw_path:
|
||||
return ""
|
||||
backslash_re = re.compile(r"[\\]+")
|
||||
unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]")
|
||||
normalized = backslash_re.sub("/", raw_path)
|
||||
normalized = normalized.strip("/")
|
||||
parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")]
|
||||
sanitized = "/".join(parts)
|
||||
sanitized = unsafe_re.sub("", sanitized)
|
||||
return sanitized
|
||||
|
||||
@ -173,7 +173,8 @@ def check_task_executor_alive():
|
||||
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
|
||||
task_executor_heartbeats[task_executor_id] = heartbeats
|
||||
if task_executor_heartbeats:
|
||||
return {"status": "alive", "message": task_executor_heartbeats}
|
||||
status = "alive" if any(task_executor_heartbeats.values()) else "timeout"
|
||||
return {"status": status, "message": task_executor_heartbeats}
|
||||
else:
|
||||
return {"status": "timeout", "message": "Not found any task executor."}
|
||||
except Exception as e:
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
@ -14,10 +14,11 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from collections import Counter
|
||||
import string
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Request
|
||||
from quart import Request
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@ -25,6 +26,7 @@ from pydantic import (
|
||||
StringConstraints,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
@ -32,7 +34,7 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
|
||||
|
||||
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
|
||||
async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""
|
||||
Validates and parses JSON requests through a multi-stage validation pipeline.
|
||||
|
||||
@ -81,7 +83,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
|
||||
from the final output after validation
|
||||
"""
|
||||
try:
|
||||
payload = request.get_json() or {}
|
||||
payload = await request.get_json() or {}
|
||||
except UnsupportedMediaType:
|
||||
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
|
||||
except BadRequest:
|
||||
@ -329,6 +331,7 @@ class RaptorConfig(Base):
|
||||
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
|
||||
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
||||
random_seed: Annotated[int, Field(default=0, ge=0)]
|
||||
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
|
||||
|
||||
|
||||
class GraphragConfig(Base):
|
||||
@ -361,10 +364,9 @@ class CreateDatasetReq(Base):
|
||||
description: Annotated[str | None, Field(default=None, max_length=65535)]
|
||||
embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")]
|
||||
permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)]
|
||||
chunk_method: Annotated[
|
||||
Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
|
||||
Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"),
|
||||
]
|
||||
chunk_method: Annotated[str | None, Field(default=None, serialization_alias="parser_id")]
|
||||
parse_type: Annotated[int | None, Field(default=None, ge=0, le=64)]
|
||||
pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")]
|
||||
parser_config: Annotated[ParserConfig | None, Field(default=None)]
|
||||
|
||||
@field_validator("avatar", mode="after")
|
||||
@ -525,6 +527,93 @@ class CreateDatasetReq(Base):
|
||||
raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)})
|
||||
return v
|
||||
|
||||
@field_validator("pipeline_id", mode="after")
|
||||
@classmethod
|
||||
def validate_pipeline_id(cls, v: str | None) -> str | None:
|
||||
"""Validate pipeline_id as 32-char lowercase hex string if provided.
|
||||
|
||||
Rules:
|
||||
- None or empty string: treat as None (not set)
|
||||
- Must be exactly length 32
|
||||
- Must contain only hex digits (0-9a-fA-F); normalized to lowercase
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if v == "":
|
||||
return None
|
||||
if len(v) != 32:
|
||||
raise PydanticCustomError("format_invalid", "pipeline_id must be 32 hex characters")
|
||||
if any(ch not in string.hexdigits for ch in v):
|
||||
raise PydanticCustomError("format_invalid", "pipeline_id must be hexadecimal")
|
||||
return v.lower()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parser_dependency(self) -> "CreateDatasetReq":
|
||||
"""
|
||||
Mixed conditional validation:
|
||||
- If parser_id is omitted (field not set):
|
||||
* If both parse_type and pipeline_id are omitted → default chunk_method = "naive"
|
||||
* If both parse_type and pipeline_id are provided → allow ingestion pipeline mode
|
||||
- If parser_id is provided (valid enum) → parse_type and pipeline_id must be None (disallow mixed usage)
|
||||
|
||||
Raises:
|
||||
PydanticCustomError with code 'dependency_error' on violation.
|
||||
"""
|
||||
# Omitted chunk_method (not in fields) logic
|
||||
if self.chunk_method is None and "chunk_method" not in self.model_fields_set:
|
||||
# All three absent → default naive
|
||||
if self.parse_type is None and self.pipeline_id is None:
|
||||
object.__setattr__(self, "chunk_method", "naive")
|
||||
return self
|
||||
# parser_id omitted: require BOTH parse_type & pipeline_id present (no partial allowed)
|
||||
if self.parse_type is None or self.pipeline_id is None:
|
||||
missing = []
|
||||
if self.parse_type is None:
|
||||
missing.append("parse_type")
|
||||
if self.pipeline_id is None:
|
||||
missing.append("pipeline_id")
|
||||
raise PydanticCustomError(
|
||||
"dependency_error",
|
||||
"parser_id omitted → required fields missing: {fields}",
|
||||
{"fields": ", ".join(missing)},
|
||||
)
|
||||
# Both provided → allow pipeline mode
|
||||
return self
|
||||
|
||||
# parser_id provided (valid): MUST NOT have parse_type or pipeline_id
|
||||
if isinstance(self.chunk_method, str):
|
||||
if self.parse_type is not None or self.pipeline_id is not None:
|
||||
invalid = []
|
||||
if self.parse_type is not None:
|
||||
invalid.append("parse_type")
|
||||
if self.pipeline_id is not None:
|
||||
invalid.append("pipeline_id")
|
||||
raise PydanticCustomError(
|
||||
"dependency_error",
|
||||
"parser_id provided → disallowed fields present: {fields}",
|
||||
{"fields": ", ".join(invalid)},
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("chunk_method", mode="wrap")
|
||||
@classmethod
|
||||
def validate_chunk_method(cls, v: Any, handler) -> Any:
|
||||
"""Wrap validation to unify error messages, including type errors (e.g. list)."""
|
||||
allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"}
|
||||
error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'"
|
||||
# Omitted field: handler won't be invoked (wrap still gets value); None treated as explicit invalid
|
||||
if v is None:
|
||||
raise PydanticCustomError("literal_error", error_msg)
|
||||
try:
|
||||
# Run inner validation (type checking)
|
||||
result = handler(v)
|
||||
except Exception:
|
||||
raise PydanticCustomError("literal_error", error_msg)
|
||||
# After handler, enforce enumeration
|
||||
if not isinstance(result, str) or result == "" or result not in allowed:
|
||||
raise PydanticCustomError("literal_error", error_msg)
|
||||
return result
|
||||
|
||||
|
||||
class UpdateDatasetReq(CreateDatasetReq):
|
||||
dataset_id: Annotated[str, Field(...)]
|
||||
|
||||
@ -23,7 +23,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from api.apps import smtp_mail_server
|
||||
from flask_mail import Message
|
||||
from flask import render_template_string
|
||||
from quart import render_template_string
|
||||
from api.utils.email_templates import EMAIL_TEMPLATES
|
||||
from selenium import webdriver
|
||||
from selenium.common.exceptions import TimeoutException
|
||||
|
||||
48
check_comment_ascii.py
Normal file
48
check_comment_ascii.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Check whether given python files contain non-ASCII comments.
|
||||
|
||||
How to check the whole git repo:
|
||||
|
||||
```
|
||||
$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py
|
||||
```
|
||||
"""
|
||||
|
||||
import sys
|
||||
import tokenize
|
||||
import ast
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
ASCII = re.compile(r"^[\n -~]*\Z") # Printable ASCII + newline
|
||||
|
||||
|
||||
def check(src: str, name: str) -> int:
|
||||
"""
|
||||
docstring line 1
|
||||
docstring line 2
|
||||
"""
|
||||
ok = 1
|
||||
# A common comment begins with `#`
|
||||
with tokenize.open(src) as fp:
|
||||
for tk in tokenize.generate_tokens(fp.readline):
|
||||
if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
|
||||
print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
|
||||
ok = 0
|
||||
# A docstring begins and ends with `'''`
|
||||
for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)):
|
||||
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
|
||||
if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
|
||||
print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
|
||||
ok = 0
|
||||
return ok
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
status = 0
|
||||
for file in sys.argv[1:]:
|
||||
if not check(file, file):
|
||||
status = 1
|
||||
sys.exit(status)
|
||||
@ -21,7 +21,7 @@ from typing import Any, Callable, Coroutine, Optional, Type, Union
|
||||
import asyncio
|
||||
import trio
|
||||
from functools import wraps
|
||||
from flask import make_response, jsonify
|
||||
from quart import make_response, jsonify
|
||||
from common.constants import RetCode
|
||||
|
||||
TimeoutException = Union[Type[BaseException], BaseException]
|
||||
@ -103,7 +103,7 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
|
||||
return decorator
|
||||
|
||||
|
||||
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
async def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
result_dict = {"code": code, "message": message, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
@ -111,7 +111,27 @@ def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
response = make_response(jsonify(response_dict))
|
||||
response = await make_response(jsonify(response_dict))
|
||||
if auth:
|
||||
response.headers["Authorization"] = auth
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Method"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Expose-Headers"] = "Authorization"
|
||||
return response
|
||||
|
||||
|
||||
def sync_construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
import flask
|
||||
result_dict = {"code": code, "message": message, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is None and key != "code":
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
response = flask.make_response(flask.jsonify(response_dict))
|
||||
if auth:
|
||||
response.headers["Authorization"] = auth
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
|
||||
@ -49,6 +49,7 @@ class RetCode(IntEnum, CustomEnum):
|
||||
RUNNING = 106
|
||||
PERMISSION_ERROR = 108
|
||||
AUTHENTICATION_ERROR = 109
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
SERVER_ERROR = 500
|
||||
FORBIDDEN = 403
|
||||
@ -118,6 +119,9 @@ class FileSource(StrEnum):
|
||||
SHAREPOINT = "sharepoint"
|
||||
SLACK = "slack"
|
||||
TEAMS = "teams"
|
||||
WEBDAV = "webdav"
|
||||
MOODLE = "moodle"
|
||||
DROPBOX = "dropbox"
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
@ -144,6 +148,7 @@ class Storage(Enum):
|
||||
AWS_S3 = 4
|
||||
OSS = 5
|
||||
OPENDAL = 6
|
||||
GCS = 7
|
||||
|
||||
# environment
|
||||
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
|
||||
|
||||
@ -11,9 +11,11 @@ from .confluence_connector import ConfluenceConnector
|
||||
from .discord_connector import DiscordConnector
|
||||
from .dropbox_connector import DropboxConnector
|
||||
from .google_drive.connector import GoogleDriveConnector
|
||||
from .jira_connector import JiraConnector
|
||||
from .jira.connector import JiraConnector
|
||||
from .sharepoint_connector import SharePointConnector
|
||||
from .teams_connector import TeamsConnector
|
||||
from .webdav_connector import WebDAVConnector
|
||||
from .moodle_connector import MoodleConnector
|
||||
from .config import BlobType, DocumentSource
|
||||
from .models import Document, TextSection, ImageSection, BasicExpertInfo
|
||||
from .exceptions import (
|
||||
@ -36,6 +38,8 @@ __all__ = [
|
||||
"JiraConnector",
|
||||
"SharePointConnector",
|
||||
"TeamsConnector",
|
||||
"WebDAVConnector",
|
||||
"MoodleConnector",
|
||||
"BlobType",
|
||||
"DocumentSource",
|
||||
"Document",
|
||||
|
||||
@ -87,6 +87,13 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
|
||||
|
||||
elif self.bucket_type == BlobType.S3_COMPATIBLE:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key", "addressing_style"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("S3 Compatible Storage")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ def get_current_tz_offset() -> int:
|
||||
return round(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
ONE_MINUTE = 60
|
||||
ONE_HOUR = 3600
|
||||
ONE_DAY = ONE_HOUR * 24
|
||||
|
||||
@ -31,6 +32,7 @@ class BlobType(str, Enum):
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
S3_COMPATIBLE = "s3_compatible"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
@ -42,9 +44,14 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
SLACK = "slack"
|
||||
CONFLUENCE = "confluence"
|
||||
JIRA = "jira"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
GMAIL = "gmail"
|
||||
DISCORD = "discord"
|
||||
WEBDAV = "webdav"
|
||||
MOODLE = "moodle"
|
||||
S3_COMPATIBLE = "s3_compatible"
|
||||
DROPBOX = "dropbox"
|
||||
|
||||
|
||||
class FileOrigin(str, Enum):
|
||||
@ -178,6 +185,21 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
if ignored_tag
|
||||
]
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
JIRA_SYNC_TIME_BUFFER_SECONDS = int(
|
||||
os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE)
|
||||
)
|
||||
JIRA_TIMEZONE_OFFSET = float(
|
||||
os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset())
|
||||
)
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
@ -195,6 +217,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback")
|
||||
GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback")
|
||||
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
@ -1562,6 +1562,7 @@ class ConfluenceConnector(
|
||||
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
|
||||
doc_updated_at=datetime_from_string(page["version"]["when"]),
|
||||
primary_owners=primary_owners if primary_owners else None,
|
||||
metadata=metadata if metadata else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
|
||||
@ -1788,6 +1789,7 @@ class ConfluenceConnector(
|
||||
cql_url = self.confluence_client.build_cql_url(
|
||||
page_query, expand=",".join(_PAGE_EXPANSION_FIELDS)
|
||||
)
|
||||
logging.info(f"[Confluence Connector] Building CQL URL {cql_url}")
|
||||
return update_param_in_path(cql_url, "limit", str(limit))
|
||||
|
||||
@override
|
||||
|
||||
@ -65,6 +65,7 @@ def _convert_message_to_document(
|
||||
blob=message.content.encode("utf-8"),
|
||||
extension=".txt",
|
||||
size_bytes=len(message.content.encode("utf-8")),
|
||||
metadata=metadata if metadata else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,13 +1,24 @@
|
||||
"""Dropbox connector"""
|
||||
|
||||
import logging
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from dropbox import Dropbox
|
||||
from dropbox.exceptions import ApiError, AuthError
|
||||
from dropbox.files import FileMetadata, FolderMetadata
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.models import Document, GenerateDocumentsOutput
|
||||
from common.data_source.utils import get_file_ext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DropboxConnector(LoadConnector, PollConnector):
|
||||
@ -19,29 +30,29 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Dropbox credentials"""
|
||||
try:
|
||||
access_token = credentials.get("dropbox_access_token")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("Dropbox access token is required")
|
||||
|
||||
self.dropbox_client = Dropbox(access_token)
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Dropbox: {e}")
|
||||
access_token = credentials.get("dropbox_access_token")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("Dropbox access token is required")
|
||||
|
||||
self.dropbox_client = Dropbox(access_token)
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Dropbox connector settings"""
|
||||
if not self.dropbox_client:
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
|
||||
try:
|
||||
# Test connection by getting current account info
|
||||
self.dropbox_client.users_get_current_account()
|
||||
except (AuthError, ApiError) as e:
|
||||
if "invalid_access_token" in str(e).lower():
|
||||
raise InsufficientPermissionsError("Invalid Dropbox access token")
|
||||
else:
|
||||
raise ConnectorValidationError(f"Dropbox validation error: {e}")
|
||||
self.dropbox_client.files_list_folder(path="", limit=1)
|
||||
except AuthError as e:
|
||||
logger.exception("[Dropbox]: Failed to validate Dropbox credentials")
|
||||
raise ConnectorValidationError(f"Dropbox credential is invalid: {e}")
|
||||
except ApiError as e:
|
||||
if e.error is not None and "insufficient_permissions" in str(e.error).lower():
|
||||
raise InsufficientPermissionsError("Your Dropbox token does not have sufficient permissions.")
|
||||
raise ConnectorValidationError(f"Unexpected Dropbox error during validation: {e.user_message_text or e}")
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Unexpected error during Dropbox settings validation: {e}")
|
||||
|
||||
def _download_file(self, path: str) -> bytes:
|
||||
"""Download a single file from Dropbox."""
|
||||
@ -54,26 +65,105 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
"""Create a shared link for a file in Dropbox."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
|
||||
try:
|
||||
# Try to get existing shared links first
|
||||
shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
|
||||
if shared_links.links:
|
||||
return shared_links.links[0].url
|
||||
|
||||
# Create a new shared link
|
||||
link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path)
|
||||
return link_settings.url
|
||||
except Exception:
|
||||
# Fallback to basic link format
|
||||
return f"https://www.dropbox.com/home{path}"
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
link_metadata = self.dropbox_client.sharing_create_shared_link_with_settings(path)
|
||||
return link_metadata.url
|
||||
except ApiError as err:
|
||||
logger.exception(f"[Dropbox]: Failed to create a shared link for {path}: {err}")
|
||||
return ""
|
||||
|
||||
def _yield_files_recursive(
|
||||
self,
|
||||
path: str,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Yield files in batches from a specified Dropbox folder, including subfolders."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
result = self.dropbox_client.files_list_folder(
|
||||
path,
|
||||
limit=self.batch_size,
|
||||
recursive=False,
|
||||
include_non_downloadable_files=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
if modified_time.tzinfo is None:
|
||||
modified_time = modified_time.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
modified_time = modified_time.astimezone(timezone.utc)
|
||||
|
||||
time_as_seconds = modified_time.timestamp()
|
||||
if start is not None and time_as_seconds <= start:
|
||||
continue
|
||||
if end is not None and time_as_seconds > end:
|
||||
continue
|
||||
|
||||
try:
|
||||
downloaded_file = self._download_file(entry.path_display)
|
||||
except Exception:
|
||||
logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
|
||||
continue
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=entry.name,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(entry, FolderMetadata):
|
||||
yield from self._yield_files_recursive(entry.path_lower, start, end)
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
if not result.has_more:
|
||||
break
|
||||
|
||||
result = self.dropbox_client.files_list_folder_continue(result.cursor)
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
|
||||
"""Poll Dropbox for recent file changes"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
def load_from_state(self) -> Any:
|
||||
for batch in self._yield_files_recursive("", start, end):
|
||||
yield batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Load files from Dropbox state"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
return self._yield_files_recursive("", None, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
connector = DropboxConnector()
|
||||
connector.load_credentials({"dropbox_access_token": os.environ.get("DROPBOX_ACCESS_TOKEN")})
|
||||
connector.validate_connector_settings()
|
||||
document_batches = connector.load_from_state()
|
||||
try:
|
||||
first_batch = next(document_batches)
|
||||
print(f"Loaded {len(first_batch)} documents in first batch.")
|
||||
for doc in first_batch:
|
||||
print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)")
|
||||
except StopIteration:
|
||||
print("No documents available in Dropbox.")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
@ -9,10 +9,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
|
||||
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
|
||||
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc
|
||||
|
||||
# Constants for Gmail API fields
|
||||
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
|
||||
@ -67,7 +67,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
|
||||
message_data += f"{name}: {value}\n"
|
||||
|
||||
message_body_text: str = get_message_body(payload)
|
||||
|
||||
return TextSection(link=link, text=message_body_text + message_data), metadata
|
||||
|
||||
|
||||
@ -97,13 +96,15 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = message_metadata.get("subject", "")
|
||||
semantic_identifier = clean_string(semantic_identifier)
|
||||
semantic_identifier = sanitize_filename(semantic_identifier)
|
||||
|
||||
if message_metadata.get("updated_at"):
|
||||
updated_at = message_metadata.get("updated_at")
|
||||
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
updated_at_datetime = gmail_time_str_to_utc(updated_at)
|
||||
|
||||
thread_id = full_thread.get("id")
|
||||
if not thread_id:
|
||||
@ -115,15 +116,24 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = "(no subject)"
|
||||
|
||||
combined_sections = "\n\n".join(
|
||||
sec.text for sec in sections if hasattr(sec, "text")
|
||||
)
|
||||
blob = combined_sections
|
||||
size_bytes = len(blob)
|
||||
extension = '.txt'
|
||||
|
||||
return Document(
|
||||
id=thread_id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
sections=sections,
|
||||
blob=blob,
|
||||
size_bytes=size_bytes,
|
||||
extension=extension,
|
||||
source=DocumentSource.GMAIL,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
doc_updated_at=updated_at_datetime,
|
||||
metadata={},
|
||||
metadata=message_metadata,
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={email_used_to_fetch_thread},
|
||||
external_user_group_ids=set(),
|
||||
@ -214,15 +224,13 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = _execute_single_retrieval(
|
||||
full_thread = _execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
full_thread = list(full_threads)[0]
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
@ -310,4 +318,30 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
import time
|
||||
import os
|
||||
from common.data_source.google_util.util import get_credentials_from_env
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
try:
|
||||
email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com")
|
||||
creds = get_credentials_from_env(email, oauth=True, source="gmail")
|
||||
print("Credentials loaded successfully")
|
||||
print(f"{creds=}")
|
||||
|
||||
connector = GmailConnector(batch_size=2)
|
||||
print("GmailConnector initialized")
|
||||
connector.load_credentials(creds)
|
||||
print("Credentials loaded into connector")
|
||||
|
||||
print("Gmail is ready to use")
|
||||
|
||||
for file in connector._fetch_threads(
|
||||
int(time.time()) - 1 * 24 * 60 * 60,
|
||||
int(time.time()),
|
||||
):
|
||||
print("new batch","-"*80)
|
||||
for f in file:
|
||||
print(f)
|
||||
print("\n\n")
|
||||
except Exception as e:
|
||||
logging.exception(f"Error loading credentials: {e}")
|
||||
@ -1,7 +1,6 @@
|
||||
"""Google Drive connector"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@ -32,7 +31,6 @@ from common.data_source.google_drive.file_retrieval import (
|
||||
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service
|
||||
from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners
|
||||
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
|
||||
@ -1138,39 +1136,6 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP
|
||||
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
|
||||
def get_credentials_from_env(email: str, oauth: bool = False) -> dict:
|
||||
try:
|
||||
if oauth:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
|
||||
else:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
|
||||
except KeyError:
|
||||
raise ValueError("Missing Google Drive credentials in environment variables")
|
||||
|
||||
try:
|
||||
credential_dict = json.loads(raw_credential_string)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON in Google Drive credentials")
|
||||
|
||||
if oauth:
|
||||
credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
refried_credential_string = json.dumps(credential_dict)
|
||||
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
|
||||
cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
|
||||
return {
|
||||
cred_key: refried_credential_string,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
|
||||
}
|
||||
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
@ -1236,7 +1201,7 @@ def yield_all_docs_from_checkpoint_connector(
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
from common.data_source.google_util.util import get_credentials_from_env
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
try:
|
||||
@ -1245,7 +1210,7 @@ if __name__ == "__main__":
|
||||
creds = get_credentials_from_env(email, oauth=True)
|
||||
print("Credentials loaded successfully")
|
||||
print(f"{creds=}")
|
||||
|
||||
sys.exit(0)
|
||||
connector = GoogleDriveConnector(
|
||||
include_shared_drives=False,
|
||||
shared_drive_urls=None,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user