Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)

Compare commits: v0.22.0 ... 856201c0f2 (141 commits)
| SHA1 | Author | Date |
|---|---|---|
| 856201c0f2 | |||
| 9d8b96c1d0 | |||
| 7c3c185038 | |||
| a9259917c6 | |||
| 8c28587821 | |||
| 12979a3f21 | |||
| 376eb15c63 | |||
| 89ba7abe30 | |||
| 2fd5ac1031 | |||
| 40e84ca41a | |||
| a28c672695 | |||
| 74e0b58d89 | |||
| 7c20c964b4 | |||
| 5d0981d046 | |||
| a793dd2ea8 | |||
| 915e385244 | |||
| 7a344a32f9 | |||
| 8c1ee3845a | |||
| 8c751d5afc | |||
| f5faf0c94f | |||
| af72e8dc33 | |||
| bcd70affb5 | |||
| 6987e9f23b | |||
| 41665b0865 | |||
| d1744aaaf3 | |||
| d5f8548200 | |||
| 4d8698624c | |||
| 1009819801 | |||
| 8fe782f4ea | |||
| 7140950e93 | |||
| 0181747881 | |||
| 3c41159d26 | |||
| e0e1d04da5 | |||
| f0a14f5fce | |||
| 174a2578e8 | |||
| a0959b9d38 | |||
| 13299197b8 | |||
| 249296e417 | |||
| db0f6840d9 | |||
| 1033a3ae26 | |||
| 1845daf41f | |||
| 4c8f9f0d77 | |||
| cc00c3ec93 | |||
| 653b785958 | |||
| 971c1bcba7 | |||
| 065917bf1c | |||
| 820934fc77 | |||
| d3d2ccc76c | |||
| c8ab9079b3 | |||
| 0d5589bfda | |||
| b846a0f547 | |||
| 69578ebfce | |||
| 06cef71ba6 | |||
| d2b1da0e26 | |||
| 7c6d30f4c8 | |||
| ea0352ee4a | |||
| fa5cf10f56 | |||
| 3fe71ab7dd | |||
| 9f715d6bc2 | |||
| 48de3b26ba | |||
| 273c4bc4d3 | |||
| 420c97199a | |||
| ecf0322165 | |||
| 38234aca53 | |||
| 1c06ec39ca | |||
| cfdccebb17 | |||
| 980a883033 | |||
| 02d429f0ca | |||
| 9c24d5d44a | |||
| 0cc5d7a8a6 | |||
| c43bf1dcf5 | |||
| f76b8279dd | |||
| db5ec89dc5 | |||
| 1c201c4d54 | |||
| ba78d0f0c2 | |||
| add8c63458 | |||
| 83661efdaf | |||
| 971197d595 | |||
| 0884e9a4d9 | |||
| 2de42f00b8 | |||
| e8fe580d7a | |||
| 62505164d5 | |||
| d1dcf3b43c | |||
| f84662d2ee | |||
| 1cb6b7f5dd | |||
| 023f509501 | |||
| 50bc53a1f5 | |||
| 8cd4882596 | |||
| 35e5fade93 | |||
| 4942a23290 | |||
| d1716d865a | |||
| c2b7c305fa | |||
| 341e5904c8 | |||
| ded9bf80c5 | |||
| fea157ba08 | |||
| 0db00f70b2 | |||
| 701761d119 | |||
| 2993fc666b | |||
| 8a6d205df0 | |||
| 912b6b023e | |||
| 89e8818dda | |||
| 1dba6b5bf9 | |||
| 3fcf2ee54c | |||
| d8f413a885 | |||
| 7264fb6978 | |||
| bd4bc57009 | |||
| 0569b50fed | |||
| 6b64641042 | |||
| 9cef3a2625 | |||
| e7e89d3ecb | |||
| 13e212c856 | |||
| 61cf430dbb | |||
| e841b09d63 | |||
| b1a1eedf53 | |||
| 68e3b33ae4 | |||
| cd55f6c1b8 | |||
| 996b5fe14e | |||
| db4fd19c82 | |||
| 12db62b9c7 | |||
| b5f2cf16bc | |||
| e27ff8d3d4 | |||
| 5f59418aba | |||
| 87e69868c0 | |||
| 72c20022f6 | |||
| 3f2472f1b9 | |||
| 1d4d67daf8 | |||
| 7538e218a5 | |||
| 6b52f7df5a | |||
| 63131ec9b2 | |||
| e8f1a245a6 | |||
| 908450509f | |||
| 70a0f081f6 | |||
| 93422fa8cc | |||
| bfc84ba95b | |||
| 871055b0fc | |||
| ba71160b14 | |||
| bd5dda6b10 | |||
| 774563970b | |||
| 83d84e90ed | |||
| 8ef2f79d0a | |||
| 296476ab89 |
.github/workflows/tests.yml (vendored, 36 changed lines)
```diff
@@ -31,7 +31,7 @@ jobs:
     name: ragflow_tests
     # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
     # https://github.com/orgs/community/discussions/26261
-    if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }}
+    if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable == true) }}
     runs-on: [ "self-hosted", "ragflow-test" ]
     steps:
       # https://github.com/hmarr/debug-action
@@ -95,6 +95,38 @@ jobs:
           version: ">=0.11.x"
           args: "check"
 
+      - name: Check comments of changed Python files
+        if: ${{ false }}
+        run: |
+          if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
+            CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
+              | grep -E '\.(py)$' || true)
+
+            if [ -n "$CHANGED_FILES" ]; then
+              echo "Check comments of changed Python files with check_comment_ascii.py"
+
+              readarray -t files <<< "$CHANGED_FILES"
+              HAS_ERROR=0
+
+              for file in "${files[@]}"; do
+                if [ -f "$file" ]; then
+                  if python3 check_comment_ascii.py "$file"; then
+                    echo "✅ $file"
+                  else
+                    echo "❌ $file"
+                    HAS_ERROR=1
+                  fi
+                fi
+              done
+
+              if [ $HAS_ERROR -ne 0 ]; then
+                exit 1
+              fi
+            else
+              echo "No Python files changed"
+            fi
+          fi
+
       - name: Build ragflow:nightly
         run: |
           RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
@@ -161,7 +193,7 @@ jobs:
           echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
 
           sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
-          uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python
+          uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
 
       - name: Run sdk tests against Elasticsearch
         run: |
```
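The new "Check comments of changed Python files" step (currently gated off with `if: ${{ false }}`) shells out to a `check_comment_ascii.py` script and treats a non-zero exit code as a failing file. That script is not part of this diff; a minimal hypothetical sketch of such a checker, assuming it only verifies that `#` comments in a file are pure ASCII, could look like:

```python
# check_comment_ascii.py -- hypothetical sketch, not the script used by the workflow.
# Exit code 0 means every comment in the given file is pure ASCII, 1 otherwise,
# which is how the workflow step above interprets the result for each file.
import sys
import tokenize


def comments_are_ascii(path: str) -> bool:
    ok = True
    with open(path, "rb") as fh:
        for tok in tokenize.tokenize(fh.readline):
            if tok.type == tokenize.COMMENT and not tok.string.isascii():
                print(f"{path}:{tok.start[0]}: non-ASCII comment: {tok.string!r}")
                ok = False
    return ok


if __name__ == "__main__":
    sys.exit(0 if comments_are_ascii(sys.argv[1]) else 1)
```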
```diff
@@ -51,7 +51,9 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
     apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
     apt install -y libjemalloc-dev && \
     apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
-    apt install -y ghostscript
+    apt install -y ghostscript && \
+    apt install -y pandoc && \
+    apt install -y texlive
 
 RUN if [ "$NEED_MIRROR" == "1" ]; then \
     pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
```
README.md (15 changed lines)
````diff
@@ -22,7 +22,7 @@
     <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
   </a>
   <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
+    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
   </a>
   <a href="https://github.com/infiniflow/ragflow/releases/latest">
     <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -85,7 +85,8 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
 
 ## 🔥 Latest Updates
 
-- 2025-11-12 Supports data synchronization from Confluence, AWS S3, Discord, Google Drive.
+- 2025-11-19 Supports Gemini 3 Pro.
+- 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, Google Drive.
 - 2025-10-23 Supports MinerU & Docling as document parsing methods.
 - 2025-10-15 Supports orchestrable ingestion pipeline.
 - 2025-08-08 Supports OpenAI's latest GPT-5 series models.
@@ -93,8 +94,6 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
 - 2025-05-23 Adds a Python/JavaScript code executor component to Agent.
 - 2025-05-05 Supports cross-language query.
 - 2025-03-19 Supports using a multi-modal model to make sense of images within PDF or DOCX files.
-- 2024-12-18 Upgrades Document Layout Analysis model in DeepDoc.
-- 2024-08-22 Support text to SQL statements through RAG.
 
 ## 🎉 Stay Tuned
 
@@ -188,13 +187,15 @@ releases! 🌟
 > All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
 > If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
 
-> The command below downloads the `v0.22.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
+> The command below downloads the `v0.22.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
 
 ```bash
 $ cd ragflow/docker
 
-# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
-
+# git checkout v0.22.1
+# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
+# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
+
 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
 
````
README_id.md (11 changed lines)
````diff
@@ -22,7 +22,7 @@
     <img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
   </a>
   <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
+    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
   </a>
   <a href="https://github.com/infiniflow/ragflow/releases/latest">
     <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
@@ -85,7 +85,8 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 
 ## 🔥 Pembaruan Terbaru
 
-- 2025-11-12 Mendukung sinkronisasi data dari Confluence, AWS S3, Discord, Google Drive.
+- 2025-11-19 Mendukung Gemini 3 Pro.
+- 2025-11-12 Mendukung sinkronisasi data dari Confluence, S3, Notion, Discord, Google Drive.
 - 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen.
 - 2025-10-15 Dukungan untuk jalur data yang terorkestrasi.
 - 2025-08-08 Mendukung model seri GPT-5 terbaru dari OpenAI.
@@ -186,12 +187,14 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 > Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
 > Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
 
-> Perintah di bawah ini mengunduh edisi v0.22.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
+> Perintah di bawah ini mengunduh edisi v0.22.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
 
 ```bash
 $ cd ragflow/docker
 
-# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
+# git checkout v0.22.1
+# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases)
+# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
 
 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
````
README_ja.md (11 changed lines)
````diff
@@ -22,7 +22,7 @@
     <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
   </a>
   <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
+    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
   </a>
   <a href="https://github.com/infiniflow/ragflow/releases/latest">
     <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -66,7 +66,8 @@
 
 ## 🔥 最新情報
 
-- 2025-11-12 Confluence、AWS S3、Discord、Google Drive からのデータ同期をサポートします。
+- 2025-11-19 Gemini 3 Proをサポートしています
+- 2025-11-12 Confluence、S3、Notion、Discord、Google Drive からのデータ同期をサポートします。
 - 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。
 - 2025-10-15 オーケストレーションされたデータパイプラインのサポート。
 - 2025-08-08 OpenAI の最新 GPT-5 シリーズモデルをサポートします。
@@ -166,12 +167,14 @@
 > 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
 > ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
 
-> 以下のコマンドは、RAGFlow Docker イメージの v0.22.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
+> 以下のコマンドは、RAGFlow Docker イメージの v0.22.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
 
 ```bash
 $ cd ragflow/docker
 
-# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
+# git checkout v0.22.1
+# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases)
+# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
 
 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
````
README_ko.md (11 changed lines)
````diff
@@ -22,7 +22,7 @@
     <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
   </a>
   <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
+    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
   </a>
   <a href="https://github.com/infiniflow/ragflow/releases/latest">
     <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -67,7 +67,8 @@
 
 ## 🔥 업데이트
 
-- 2025-11-12 Confluence, AWS S3, Discord, Google Drive에서 데이터 동기화를 지원합니다.
+- 2025-11-19 Gemini 3 Pro를 지원합니다.
+- 2025-11-12 Confluence, S3, Notion, Discord, Google Drive에서 데이터 동기화를 지원합니다.
 - 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다.
 - 2025-10-15 조정된 데이터 파이프라인 지원.
 - 2025-08-08 OpenAI의 최신 GPT-5 시리즈 모델을 지원합니다.
@@ -168,12 +169,14 @@
 > 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
 > ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
 
-> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
+> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
 
 ```bash
 $ cd ragflow/docker
 
-# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
+# git checkout v0.22.1
+# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
+# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
 
 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
````
````diff
@@ -22,7 +22,7 @@
     <img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
   </a>
   <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
+    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
   </a>
   <a href="https://github.com/infiniflow/ragflow/releases/latest">
     <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
@@ -86,7 +86,8 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 
 ## 🔥 Últimas Atualizações
 
-- 12-11-2025 Suporta a sincronização de dados do Confluence, AWS S3, Discord e Google Drive.
+- 19-11-2025 Suporta Gemini 3 Pro.
+- 12-11-2025 Suporta a sincronização de dados do Confluence, S3, Notion, Discord e Google Drive.
 - 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos.
 - 15-10-2025 Suporte para pipelines de dados orquestrados.
 - 08-08-2025 Suporta a mais recente série GPT-5 da OpenAI.
@@ -186,12 +187,14 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 > Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
 > Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
 
-> O comando abaixo baixa a edição`v0.22.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
+> O comando abaixo baixa a edição`v0.22.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
 
 ```bash
 $ cd ragflow/docker
 
-# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
+# git checkout v0.22.1
+# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases)
+# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
 
 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
````
````diff
@@ -22,7 +22,7 @@
     <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
   </a>
   <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
+    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
   </a>
   <a href="https://github.com/infiniflow/ragflow/releases/latest">
     <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -85,7 +85,8 @@
 
 ## 🔥 近期更新
 
-- 2025-11-12 支援從 Confluence、AWS S3、Discord、Google Drive 進行資料同步。
+- 2025-11-19 支援 Gemini 3 Pro.
+- 2025-11-12 支援從 Confluence、S3、Notion、Discord、Google Drive 進行資料同步。
 - 2025-10-23 支援 MinerU 和 Docling 作為文件解析方法。
 - 2025-10-15 支援可編排的資料管道。
 - 2025-08-08 支援 OpenAI 最新的 GPT-5 系列模型。
@@ -185,12 +186,14 @@
 > 所有 Docker 映像檔都是為 x86 平台建置的。目前，我們不提供 ARM64 平台的 Docker 映像檔。
 > 如果您使用的是 ARM64 平台，請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。
 
-> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.0` 的 Docker 映像，請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
+> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.1` 的 Docker 映像，請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
 
 ```bash
 $ cd ragflow/docker
 
-# 可選：使用穩定版標籤（查看發佈：https://github.com/infiniflow/ragflow/releases），例：git checkout v0.22.0
+# git checkout v0.22.1
+# 可選：使用穩定版標籤（查看發佈：https://github.com/infiniflow/ragflow/releases）
+# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。
 
 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
````
README_zh.md (11 changed lines)
````diff
@@ -22,7 +22,7 @@
     <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
   </a>
   <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
+    <img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.1">
   </a>
   <a href="https://github.com/infiniflow/ragflow/releases/latest">
     <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -85,7 +85,8 @@
 
 ## 🔥 近期更新
 
-- 2025-11-12 支持从 Confluence、AWS S3、Discord、Google Drive 进行数据同步。
+- 2025-11-19 支持 Gemini 3 Pro.
+- 2025-11-12 支持从 Confluence、S3、Notion、Discord、Google Drive 进行数据同步。
 - 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。
 - 2025-10-15 支持可编排的数据管道。
 - 2025-08-08 支持 OpenAI 最新的 GPT-5 系列模型。
@@ -186,12 +187,14 @@
 > 请注意，目前官方提供的所有 Docker 镜像均基于 x86 架构构建，并不提供基于 ARM64 的 Docker 镜像。
 > 如果你的操作系统是 ARM64 架构，请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。
 
-> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.0` 的 Docker 镜像，请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
+> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.1` 的 Docker 镜像，请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
 
 ```bash
 $ cd ragflow/docker
 
-# 可选：使用稳定版本标签（查看发布：https://github.com/infiniflow/ragflow/releases），例如：git checkout v0.22.0
+# git checkout v0.22.1
+# 可选：使用稳定版本标签（查看发布：https://github.com/infiniflow/ragflow/releases）
+# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。
 
 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d
````
````diff
@@ -4,7 +4,7 @@
 
 Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
 
-The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
+The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
 
 For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
 
@@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
 1. Ensure the Admin Service is running.
 2. Install ragflow-cli.
 ```bash
-pip install ragflow-cli==0.22.0
+pip install ragflow-cli==0.22.1
 ```
 3. Launch the CLI client:
 ```bash
````
```diff
@@ -378,7 +378,7 @@ class AdminCLI(Cmd):
         self.session.headers.update({
             'Content-Type': 'application/json',
             'Authorization': response.headers['Authorization'],
-            'User-Agent': 'RAGFlow-CLI/0.22.0'
+            'User-Agent': 'RAGFlow-CLI/0.22.1'
         })
         print("Authentication successful.")
         return True
@@ -393,7 +393,9 @@ class AdminCLI(Cmd):
             print(f"Can't access {self.host}, port: {self.port}")
 
     def _format_service_detail_table(self, data):
-        if not any([isinstance(v, list) for v in data.values()]):
+        if isinstance(data, list):
+            return data
+        if not all([isinstance(v, list) for v in data.values()]):
             # normal table
             return data
         # handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
@@ -404,7 +406,7 @@ class AdminCLI(Cmd):
             task_executor_list.append({
                 "task_executor_name": k,
                 **heartbeats[0],
-            })
+            } if heartbeats else {"task_executor_name": k})
         return task_executor_list
 
     def _print_table_simple(self, data):
@@ -415,7 +417,8 @@ class AdminCLI(Cmd):
             # handle single row data
             data = [data]
 
-        columns = list(data[0].keys())
+        columns = list(set().union(*(d.keys() for d in data)))
+        columns.sort()
         col_widths = {}
 
         def get_string_width(text):
```
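These `AdminCLI` changes make the service-detail view tolerant of list-shaped payloads, of task executors whose heartbeat list is empty, and of rows whose key sets differ. A minimal standalone sketch of the reshaping after this patch (simplified from the diff above; the input shape is assumed from the comment in the code, not taken from the repo):

```python
# Illustration only: condensed from the AdminCLI diff above, not the shipped code.
def format_service_detail_table(data):
    if isinstance(data, list):
        return data
    if not all(isinstance(v, list) for v in data.values()):
        return data  # already a normal single-row table
    rows = []
    for name, heartbeats in data.items():
        # Keep the first heartbeat entry; executors with no heartbeat still get a row.
        rows.append({"task_executor_name": name, **heartbeats[0]} if heartbeats else {"task_executor_name": name})
    return rows


def table_columns(rows):
    # Union of keys across all rows, so rows with missing fields no longer break rendering.
    return sorted(set().union(*(row.keys() for row in rows)))


if __name__ == "__main__":
    detail = {"executor-1": [{"done": 2, "now": 1733600000}], "executor-2": []}
    rows = format_service_detail_table(detail)
    print(table_columns(rows))  # ['done', 'now', 'task_executor_name']
```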
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "ragflow-cli"
-version = "0.22.0"
+version = "0.22.1"
 description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
 authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
 license = { text = "Apache License, Version 2.0" }
@@ -8,7 +8,7 @@ readme = "README.md"
 requires-python = ">=3.10,<3.13"
 dependencies = [
     "requests>=2.30.0,<3.0.0",
-    "beartype>=0.18.5,<0.19.0",
+    "beartype>=0.20.0,<1.0.0",
     "pycryptodomex>=3.10.0",
     "lark>=1.1.0",
 ]
```
admin/client/uv.lock (generated, new file, 298 lines)
The new lock file opens with `version = 1`, `revision = 3`, and `requires-python = ">=3.10, <3.13"`, then records one `[[package]]` stanza per resolved dependency against the https://pypi.tuna.tsinghua.edu.cn/simple registry, each with its sdist and wheel URLs, SHA-256 hashes, sizes, and upload times. The pinned packages are beartype 0.22.6, certifi 2025.11.12, charset-normalizer 3.4.4, colorama 0.4.6, exceptiongroup 1.3.1, idna 3.11, iniconfig 2.3.0, lark 1.3.1, packaging 25.0, pluggy 1.6.0, pycryptodomex 3.23.0, pygments 2.19.2, and pytest 9.0.1, plus the virtual ragflow-cli 0.22.1 package itself, whose listed dependencies begin with beartype, lark, and pycryptodomex.
|
||||
{ name = "requests" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "requests" },
|
||||
{ name = "requests-toolbelt" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "beartype", specifier = ">=0.20.0,<1.0.0" },
|
||||
{ name = "lark", specifier = ">=1.1.0" },
|
||||
{ name = "pycryptodomex", specifier = ">=3.10.0" },
|
||||
{ name = "requests", specifier = ">=2.30.0,<3.0.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
test = [
|
||||
{ name = "pytest", specifier = ">=8.3.5" },
|
||||
{ name = "requests", specifier = ">=2.32.3" },
|
||||
{ name = "requests-toolbelt", specifier = ">=1.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.5"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "charset-normalizer" },
|
||||
{ name = "idna" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests-toolbelt"
|
||||
version = "1.0.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/61/d7545dafb7ac2230c70d38d31cbfe4cc64f7144dc41f6e4e4b78ecd9f5bb/requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6", size = 206888, upload-time = "2023-05-01T04:11:33.229Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.15.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.5.0"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" },
|
||||
]
|
||||
@ -20,8 +20,11 @@ import logging
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
from werkzeug.serving import run_simple
|
||||
import faulthandler
|
||||
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager
|
||||
from werkzeug.serving import run_simple
|
||||
from routes import admin_bp
|
||||
from common.log_utils import init_root_logger
|
||||
from common.constants import SERVICE_CONF
|
||||
@ -30,12 +33,12 @@ from common import settings
|
||||
from config import load_configurations, SERVICE_CONFIGS
|
||||
from auth import init_default_admin, setup_auth
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from common.versions import get_ragflow_version
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
if __name__ == '__main__':
|
||||
faulthandler.enable()
|
||||
init_root_logger("admin_service")
|
||||
logging.info(r"""
|
||||
____ ___ ______________ ___ __ _
|
||||
|
||||
@ -19,7 +19,8 @@ import logging
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
from flask import request, jsonify
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask_login import current_user, login_user
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
@ -30,7 +31,7 @@ from common.constants import ActiveEnum, StatusEnum
|
||||
from api.utils.crypt import decrypt
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, datetime_format, get_format_time
|
||||
from common.connection_utils import construct_response
|
||||
from common.connection_utils import sync_construct_response
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -129,7 +130,7 @@ def login_admin(email: str, password: str):
|
||||
user.last_login_time = get_format_time()
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=resp, auth=user.get_id(), message=msg)
|
||||
return sync_construct_response(data=resp, auth=user.get_id(), message=msg)
|
||||
|
||||
|
||||
def check_admin(username: str, password: str):
|
||||
@ -169,7 +170,7 @@ def login_verify(f):
|
||||
username = auth.parameters['username']
|
||||
password = auth.parameters['password']
|
||||
try:
|
||||
if check_admin(username, password) is False:
|
||||
if not check_admin(username, password):
|
||||
return jsonify({
|
||||
"code": 500,
|
||||
"message": "Access denied",
|
||||
|
||||
@ -25,8 +25,21 @@ from common.config_utils import read_config
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
service_type: str
|
||||
detail_func_name: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
|
||||
'service_type': self.service_type}
|
||||
|
||||
|
||||
class ServiceConfigs:
|
||||
configs = dict
|
||||
configs = list[BaseConfig]
|
||||
|
||||
def __init__(self):
|
||||
self.configs = []
|
||||
@ -45,19 +58,6 @@ class ServiceType(Enum):
|
||||
FILE_STORE = "file_store"
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
host: str
|
||||
port: int
|
||||
service_type: str
|
||||
detail_func_name: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
|
||||
'service_type': self.service_type}
|
||||
|
||||
|
||||
class MetaConfig(BaseConfig):
|
||||
meta_type: str
|
||||
|
||||
@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
|
||||
ragflow_count = 0
|
||||
id_count = 0
|
||||
for k, v in raw_configs.items():
|
||||
match (k):
|
||||
match k:
|
||||
case "ragflow":
|
||||
name: str = f'ragflow_{ragflow_count}'
|
||||
host: str = v['host']
|
||||
|
||||
@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import secrets
|
||||
|
||||
from flask import Blueprint, request
|
||||
from flask_login import current_user, logout_user, login_required
|
||||
from flask_login import current_user, login_required, logout_user
|
||||
|
||||
from auth import login_verify, login_admin, check_admin_auth
|
||||
from responses import success_response, error_response
|
||||
|
||||
@ -13,8 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from werkzeug.security import check_password_hash
|
||||
from common.constants import ActiveEnum
|
||||
@ -190,7 +189,8 @@ class ServiceMgr:
|
||||
config_dict['status'] = service_detail['status']
|
||||
else:
|
||||
config_dict['status'] = 'timeout'
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.warning(f"Can't get service details, error: {e}")
|
||||
config_dict['status'] = 'timeout'
|
||||
if not config_dict['host']:
|
||||
config_dict['host'] = '-'
|
||||
@ -205,17 +205,13 @@ class ServiceMgr:
|
||||
|
||||
@staticmethod
|
||||
def get_service_details(service_id: int):
|
||||
service_id = int(service_id)
|
||||
service_idx = int(service_id)
|
||||
configs = SERVICE_CONFIGS.configs
|
||||
service_config_mapping = {
|
||||
c.id: {
|
||||
'name': c.name,
|
||||
'detail_func_name': c.detail_func_name
|
||||
} for c in configs
|
||||
}
|
||||
service_info = service_config_mapping.get(service_id, {})
|
||||
if not service_info:
|
||||
raise AdminException(f"invalid service_id: {service_id}")
|
||||
if service_idx < 0 or service_idx >= len(configs):
|
||||
raise AdminException(f"invalid service_index: {service_idx}")
|
||||
|
||||
service_config = configs[service_idx]
|
||||
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
|
||||
|
||||
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
|
||||
res = detail_func()
|
||||
|
||||
@ -25,7 +25,6 @@ from typing import Any, Union, Tuple
|
||||
|
||||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
@ -207,17 +206,60 @@ class Graph:
|
||||
for key in path.split('.'):
|
||||
if cur is None:
|
||||
return None
|
||||
|
||||
if isinstance(cur, str):
|
||||
try:
|
||||
cur = json.loads(cur)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(key)
|
||||
else:
|
||||
cur = getattr(cur, key, None)
|
||||
continue
|
||||
|
||||
if isinstance(cur, (list, tuple)):
|
||||
try:
|
||||
idx = int(key)
|
||||
cur = cur[idx]
|
||||
except Exception:
|
||||
return None
|
||||
continue
|
||||
|
||||
cur = getattr(cur, key, None)
|
||||
return cur
|
||||
|
||||
def set_variable_value(self, exp: str, value):
|
||||
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
self.globals[exp] = value
|
||||
return
|
||||
cpn_id, var_nm = exp.split("@")
|
||||
cpn = self.get_component(cpn_id)
|
||||
if not cpn:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
parts = var_nm.split(".", 1)
|
||||
root_key = parts[0]
|
||||
rest = parts[1] if len(parts) > 1 else ""
|
||||
if not rest:
|
||||
cpn["obj"].set_output(root_key, value)
|
||||
return
|
||||
root_val = cpn["obj"].output(root_key)
|
||||
if not root_val:
|
||||
root_val = {}
|
||||
cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value))
|
||||
|
||||
def set_variable_param_value(self, obj: Any, path: str, value) -> Any:
|
||||
cur = obj
|
||||
keys = path.split('.')
|
||||
if not path:
|
||||
return value
|
||||
for key in keys:
|
||||
if key not in cur or not isinstance(cur[key], dict):
|
||||
cur[key] = {}
|
||||
cur = cur[key]
|
||||
cur[keys[-1]] = value
|
||||
return obj
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
return has_canceled(self.task_id)
|
||||
|
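The two Graph methods above implement dotted variable paths: get_variable_value walks dicts, JSON-encoded strings and list indices, while set_variable_value splits "component@output.field" references and delegates nested writes to set_variable_param_value. Note that set_variable_param_value descends through every key, including the last, before assigning, so a path like "a.b" ends up writing obj["a"]["b"]["b"]; stopping one key early, as in this minimal standalone sketch (plain dicts only, illustrative names, not the Canvas API), is presumably the intended behaviour:

def get_by_path(data, path):
    # Walk "a.b.0"-style paths across dicts and lists; return None on any miss.
    cur = data
    for key in path.split("."):
        if isinstance(cur, dict):
            cur = cur.get(key)
        elif isinstance(cur, (list, tuple)) and key.isdigit() and int(key) < len(cur):
            cur = cur[int(key)]
        else:
            return None
    return cur

def set_by_path(data, path, value):
    # Create intermediate dicts for every key except the last, then assign.
    keys = path.split(".")
    cur = data
    for key in keys[:-1]:
        cur = cur.setdefault(key, {})
    cur[keys[-1]] = value
    return data

state = {}
set_by_path(state, "structured.total", 42)
assert get_by_path(state, "structured.total") == 42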
||||
@ -270,7 +312,7 @@ class Canvas(Graph):
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
for k in self.globals.keys():
|
||||
if k.startswith("sys."):
|
||||
if k.startswith("sys.") or k.startswith("env."):
|
||||
if isinstance(self.globals[k], str):
|
||||
self.globals[k] = ""
|
||||
elif isinstance(self.globals[k], int):
|
||||
@ -284,7 +326,7 @@ class Canvas(Graph):
|
||||
else:
|
||||
self.globals[k] = None
|
||||
|
||||
def run(self, **kwargs):
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.message_id = get_uuid()
|
||||
created_at = int(time.time())
|
||||
@ -298,8 +340,6 @@ class Canvas(Graph):
|
||||
for kk, vv in kwargs["webhook_payload"].items():
|
||||
self.components[k]["obj"].set_output(kk, vv)
|
||||
|
||||
self.components[k]["obj"].reset(True)
|
||||
|
||||
for k in kwargs.keys():
|
||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||
if k == "files":
|
||||
@ -408,6 +448,10 @@ class Canvas(Graph):
|
||||
else:
|
||||
yield decorate("message", {"content": cpn_obj.output("content")})
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
||||
|
||||
if isinstance(cpn_obj.output("attachment"), tuple):
|
||||
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
|
||||
|
||||
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
|
||||
|
||||
while partials:
|
||||
@ -547,6 +591,7 @@ class Canvas(Graph):
|
||||
return self.components[cpnnm]["obj"].get_input_elements()
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
from api.db.services.file_service import FileService
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
@ -613,4 +658,3 @@ class Canvas(Graph):
|
||||
|
||||
def get_component_thoughts(self, cpn_id) -> str:
|
||||
return self.components.get(cpn_id)["obj"].thoughts()
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService
|
||||
from common.connection_utils import timeout
|
||||
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
|
||||
@ -163,12 +163,7 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
output_structure=None
|
||||
try:
|
||||
output_structure=self._param.outputs['structured']
|
||||
except Exception:
|
||||
pass
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||
return
|
||||
|
||||
@ -368,11 +363,19 @@ Respond immediately with your final comprehensive answer.
|
||||
|
||||
return "Error occurred."
|
||||
|
||||
def reset(self, temp=False):
|
||||
def reset(self, only_output=False):
|
||||
"""
|
||||
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
|
||||
"""
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
|
||||
for k, cpn in self.tools.items():
|
||||
if hasattr(cpn, "reset") and callable(cpn.reset):
|
||||
cpn.reset()
|
||||
if only_output:
|
||||
return
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
|
||||
@ -463,12 +463,15 @@ class ComponentBase(ABC):
|
||||
return self._param.outputs.get("_ERROR", {}).get("value")
|
||||
|
||||
def reset(self, only_output=False):
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
outputs: dict = self._param.outputs # for better performance
|
||||
for k in outputs.keys():
|
||||
outputs[k]["value"] = None
|
||||
if only_output:
|
||||
return
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
|
||||
inputs: dict = self._param.inputs # for better performance
|
||||
for k in inputs.keys():
|
||||
inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
|
||||
|
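Both reset() implementations above follow the same contract: output slots are always cleared, while inputs and debug_inputs survive when only_output=True. A minimal sketch of that pattern on a bare parameter holder (illustrative names, not the real ComponentBase):

class Params:
    def __init__(self):
        self.inputs = {"query": {"value": "hello"}}
        self.outputs = {"content": {"value": "old answer"}}
        self.debug_inputs = {"query": "hello"}

def reset(params, only_output=False):
    # Outputs are always wiped; a partial reset keeps inputs for the next run.
    for slot in params.outputs.values():
        slot["value"] = None
    if only_output:
        return
    for slot in params.inputs.values():
        slot["value"] = None
    params.debug_inputs = {}

p = Params()
reset(p, only_output=True)
print(p.outputs["content"]["value"], p.inputs["query"]["value"])  # None hello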
||||
@ -1,3 +1,18 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import ast
|
||||
import os
|
||||
|
||||
@ -32,6 +32,7 @@ class IterationParam(ComponentParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.items_ref = ""
|
||||
self.variable={}
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
|
||||
168
agent/component/list_operations.py
Normal file
@ -0,0 +1,168 @@
|
||||
from abc import ABC
|
||||
import os
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
class ListOperationsParam(ComponentParamBase):
|
||||
"""
|
||||
Define the List Operations component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.query = ""
|
||||
self.operations = "topN"
|
||||
self.n=0
|
||||
self.sort_method = "asc"
|
||||
self.filter = {
|
||||
"operator": "=",
|
||||
"value": ""
|
||||
}
|
||||
self.outputs = {
|
||||
"result": {
|
||||
"value": [],
|
||||
"type": "Array of ?"
|
||||
},
|
||||
"first": {
|
||||
"value": "",
|
||||
"type": "?"
|
||||
},
|
||||
"last": {
|
||||
"value": "",
|
||||
"type": "?"
|
||||
}
|
||||
}
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.query, "query")
|
||||
self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {}
|
||||
|
||||
|
||||
class ListOperations(ComponentBase,ABC):
|
||||
component_name = "ListOperations"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
self.input_objects=[]
|
||||
inputs = getattr(self._param, "query", None)
|
||||
self.inputs = self._canvas.get_variable_value(inputs)
|
||||
if not isinstance(self.inputs, list):
|
||||
raise TypeError("The input of List Operations should be an array.")
|
||||
self.set_input_value(inputs, self.inputs)
|
||||
if self._param.operations == "topN":
|
||||
self._topN()
|
||||
elif self._param.operations == "head":
|
||||
self._head()
|
||||
elif self._param.operations == "tail":
|
||||
self._tail()
|
||||
elif self._param.operations == "filter":
|
||||
self._filter()
|
||||
elif self._param.operations == "sort":
|
||||
self._sort()
|
||||
elif self._param.operations == "drop_duplicates":
|
||||
self._drop_duplicates()
|
||||
|
||||
|
||||
def _coerce_n(self):
|
||||
try:
|
||||
return int(getattr(self._param, "n", 0))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _set_outputs(self, outputs):
|
||||
self._param.outputs["result"]["value"] = outputs
|
||||
self._param.outputs["first"]["value"] = outputs[0] if outputs else None
|
||||
self._param.outputs["last"]["value"] = outputs[-1] if outputs else None
|
||||
|
||||
def _topN(self):
|
||||
n = self._coerce_n()
|
||||
if n < 1:
|
||||
outputs = []
|
||||
else:
|
||||
n = min(n, len(self.inputs))
|
||||
outputs = self.inputs[:n]
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _head(self):
|
||||
n = self._coerce_n()
|
||||
if 1 <= n <= len(self.inputs):
|
||||
outputs = [self.inputs[n - 1]]
|
||||
else:
|
||||
outputs = []
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _tail(self):
|
||||
n = self._coerce_n()
|
||||
if 1 <= n <= len(self.inputs):
|
||||
outputs = [self.inputs[-n]]
|
||||
else:
|
||||
outputs = []
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _filter(self):
|
||||
self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])])
|
||||
|
||||
def _norm(self,v):
|
||||
s = "" if v is None else str(v)
|
||||
return s
|
||||
|
||||
def _eval(self, v, operator, value):
|
||||
if operator == "=":
|
||||
return v == value
|
||||
elif operator == "≠":
|
||||
return v != value
|
||||
elif operator == "contains":
|
||||
return value in v
|
||||
elif operator == "start with":
|
||||
return v.startswith(value)
|
||||
elif operator == "end with":
|
||||
return v.endswith(value)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _sort(self):
|
||||
items = self.inputs or []
|
||||
method = getattr(self._param, "sort_method", "asc") or "asc"
|
||||
reverse = method == "desc"
|
||||
|
||||
if not items:
|
||||
self._set_outputs([])
|
||||
return
|
||||
|
||||
first = items[0]
|
||||
|
||||
if isinstance(first, dict):
|
||||
outputs = sorted(
|
||||
items,
|
||||
key=lambda x: self._hashable(x),
|
||||
reverse=reverse,
|
||||
)
|
||||
else:
|
||||
outputs = sorted(items, reverse=reverse)
|
||||
|
||||
self._set_outputs(outputs)
|
||||
|
||||
def _drop_duplicates(self):
|
||||
seen = set()
|
||||
outs = []
|
||||
for item in self.inputs:
|
||||
k = self._hashable(item)
|
||||
if k in seen:
|
||||
continue
|
||||
seen.add(k)
|
||||
outs.append(item)
|
||||
self._set_outputs(outs)
|
||||
|
||||
def _hashable(self,x):
|
||||
if isinstance(x, dict):
|
||||
return tuple(sorted((k, self._hashable(v)) for k, v in x.items()))
|
||||
if isinstance(x, (list, tuple)):
|
||||
return tuple(self._hashable(v) for v in x)
|
||||
if isinstance(x, set):
|
||||
return tuple(sorted(self._hashable(v) for v in x))
|
||||
return x
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "ListOperation in progress"
|
||||
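The new ListOperations component dispatches on the operations parameter ("topN", "head", "tail", "filter", "sort", "drop_duplicates") and always publishes the transformed list plus its first and last elements. A rough standalone sketch of the same semantics on a plain Python list (the real component resolves its input through canvas variables, which is omitted here, and supports more filter operators):

def list_op(items, operation, n=0, sort_method="asc", flt=None):
    # topN slices, head/tail pick the n-th element from either end (1-based),
    # the remaining operations transform the list as a whole.
    if operation == "topN":
        out = items[:n] if n >= 1 else []
    elif operation == "head":
        out = [items[n - 1]] if 1 <= n <= len(items) else []
    elif operation == "tail":
        out = [items[-n]] if 1 <= n <= len(items) else []
    elif operation == "filter":
        op, value = flt["operator"], flt["value"]
        out = [i for i in items if (op == "=" and str(i) == value) or (op == "contains" and value in str(i))]
    elif operation == "sort":
        out = sorted(items, reverse=(sort_method == "desc"))
    else:  # drop_duplicates
        seen, out = set(), []
        for i in items:
            if i not in seen:
                seen.add(i)
                out.append(i)
    return {"result": out, "first": out[0] if out else None, "last": out[-1] if out else None}

print(list_op([3, 1, 2, 3], "drop_duplicates"))  # {'result': [3, 1, 2], 'first': 3, 'last': 2}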
@ -222,7 +222,7 @@ class LLM(ComponentBase):
|
||||
output_structure = self._param.outputs['structured']
|
||||
except Exception:
|
||||
pass
|
||||
if output_structure:
|
||||
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
|
||||
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt += structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
@ -249,7 +249,7 @@ class LLM(ComponentBase):
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self._stream_output, prompt, msg))
|
||||
return
|
||||
|
||||
|
||||
@ -17,6 +17,8 @@ import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import logging
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
@ -24,6 +26,8 @@ from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from jinja2 import Template as Jinja2Template
|
||||
|
||||
from common.connection_utils import timeout
|
||||
from common.misc_utils import get_uuid
|
||||
from common import settings
|
||||
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
@ -34,6 +38,7 @@ class MessageParam(ComponentParamBase):
|
||||
super().__init__()
|
||||
self.content = []
|
||||
self.stream = True
|
||||
self.output_format = None # default output format
|
||||
self.outputs = {
|
||||
"content": {
|
||||
"type": "str"
|
||||
@ -133,6 +138,7 @@ class Message(ComponentBase):
|
||||
yield rand_cnt[s: ]
|
||||
|
||||
self.set_output("content", all_content)
|
||||
self._convert_content(all_content)
|
||||
|
||||
def _is_jinjia2(self, content:str) -> bool:
|
||||
patt = [
|
||||
@ -164,6 +170,72 @@ class Message(ComponentBase):
|
||||
content = re.sub(n, v, content)
|
||||
|
||||
self.set_output("content", content)
|
||||
self._convert_content(content)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
|
||||
def _convert_content(self, content):
|
||||
if not self._param.output_format:
|
||||
return
|
||||
|
||||
import pypandoc
|
||||
doc_id = get_uuid()
|
||||
|
||||
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
|
||||
self._param.output_format = "markdown"
|
||||
|
||||
try:
|
||||
if self._param.output_format in {"markdown", "html"}:
|
||||
if isinstance(content, str):
|
||||
converted = pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
else:
|
||||
converted = pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
|
||||
binary_content = converted.encode("utf-8")
|
||||
|
||||
else: # pdf, docx
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp:
|
||||
tmp_name = tmp.name
|
||||
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
else:
|
||||
pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
|
||||
with open(tmp_name, "rb") as f:
|
||||
binary_content = f.read()
|
||||
|
||||
finally:
|
||||
if os.path.exists(tmp_name):
|
||||
os.remove(tmp_name)
|
||||
|
||||
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
|
||||
self.set_output("attachment", {
|
||||
"doc_id":doc_id,
|
||||
"format":self._param.output_format,
|
||||
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
|
||||
|
||||
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
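_convert_content above leans on pypandoc: markdown and html targets go through convert_text and are encoded straight to bytes, while pdf and docx targets have to be written to a temporary file via outputfile and read back before upload. A minimal standalone illustration of that split (assumes pypandoc and a pandoc binary are installed; the input text is arbitrary):

import tempfile
import pypandoc

md = "# Report\n\nHello **world**."

# Text-to-text conversion returns the converted string directly.
html = pypandoc.convert_text(md, to="html", format="markdown")

# Binary targets such as docx must go through an output file.
with tempfile.NamedTemporaryFile(suffix=".docx") as tmp:
    pypandoc.convert_text(md, to="docx", format="markdown", outputfile=tmp.name)
    with open(tmp.name, "rb") as f:
        data = f.read()  # these are the bytes the component uploads to storage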
192
agent/component/variable_assigner.py
Normal file
@ -0,0 +1,192 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import os
|
||||
import numbers
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
class VariableAssignerParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Variable Assigner component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.variables=[]
|
||||
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"items": {
|
||||
"type": "json",
|
||||
"name": "Items"
|
||||
}
|
||||
}
|
||||
|
||||
class VariableAssigner(ComponentBase,ABC):
|
||||
component_name = "VariableAssigner"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not isinstance(self._param.variables,list):
|
||||
return
|
||||
else:
|
||||
for item in self._param.variables:
|
||||
if any([not item.get("variable"), not item.get("operator"), not item.get("parameter")]):
|
||||
assert "Variable is not complete."
|
||||
variable=item["variable"]
|
||||
operator=item["operator"]
|
||||
parameter=item["parameter"]
|
||||
variable_value=self._canvas.get_variable_value(variable)
|
||||
new_variable=self._operate(variable_value,operator,parameter)
|
||||
self._canvas.set_variable_value(variable, new_variable)
|
||||
|
||||
def _operate(self,variable,operator,parameter):
|
||||
if operator == "overwrite":
|
||||
return self._overwrite(parameter)
|
||||
elif operator == "clear":
|
||||
return self._clear(variable)
|
||||
elif operator == "set":
|
||||
return self._set(variable,parameter)
|
||||
elif operator == "append":
|
||||
return self._append(variable,parameter)
|
||||
elif operator == "extend":
|
||||
return self._extend(variable,parameter)
|
||||
elif operator == "remove_first":
|
||||
return self._remove_first(variable)
|
||||
elif operator == "remove_last":
|
||||
return self._remove_last(variable)
|
||||
elif operator == "+=":
|
||||
return self._add(variable,parameter)
|
||||
elif operator == "-=":
|
||||
return self._subtract(variable,parameter)
|
||||
elif operator == "*=":
|
||||
return self._multiply(variable,parameter)
|
||||
elif operator == "/=":
|
||||
return self._divide(variable,parameter)
|
||||
else:
|
||||
return
|
||||
|
||||
def _overwrite(self,parameter):
|
||||
return self._canvas.get_variable_value(parameter)
|
||||
|
||||
def _clear(self,variable):
|
||||
if isinstance(variable,list):
|
||||
return []
|
||||
elif isinstance(variable,str):
|
||||
return ""
|
||||
elif isinstance(variable,dict):
|
||||
return {}
|
||||
elif isinstance(variable,int):
|
||||
return 0
|
||||
elif isinstance(variable,float):
|
||||
return 0.0
|
||||
elif isinstance(variable,bool):
|
||||
return False
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set(self,variable,parameter):
|
||||
if variable is None:
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,str):
|
||||
return self._canvas.get_value_with_variable(parameter)
|
||||
elif isinstance(variable,bool):
|
||||
return parameter
|
||||
elif isinstance(variable,int):
|
||||
return parameter
|
||||
elif isinstance(variable,float):
|
||||
return parameter
|
||||
else:
|
||||
return parameter
|
||||
|
||||
def _append(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif len(variable)!=0 and not isinstance(parameter,type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
variable.append(parameter)
|
||||
return variable
|
||||
|
||||
def _extend(self,variable,parameter):
|
||||
parameter=self._canvas.get_variable_value(parameter)
|
||||
if variable is None:
|
||||
variable=[]
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
elif not isinstance(parameter,list):
|
||||
return "ERROR:PARAMETER_NOT_LIST"
|
||||
elif len(variable)!=0 and len(parameter)!=0 and not isinstance(parameter[0],type(variable[0])):
|
||||
return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE"
|
||||
else:
|
||||
return variable + parameter
|
||||
|
||||
def _remove_first(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[1:]
|
||||
|
||||
def _remove_last(self,variable):
|
||||
if len(variable)==0:
|
||||
return variable
|
||||
if not isinstance(variable,list):
|
||||
return "ERROR:VARIABLE_NOT_LIST"
|
||||
else:
|
||||
return variable[:-1]
|
||||
|
||||
def is_number(self, value):
|
||||
if isinstance(value, bool):
|
||||
return False
|
||||
return isinstance(value, numbers.Number)
|
||||
|
||||
def _add(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable + parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _subtract(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable - parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _multiply(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
return variable * parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def _divide(self,variable,parameter):
|
||||
if self.is_number(variable) and self.is_number(parameter):
|
||||
if parameter==0:
|
||||
return "ERROR:DIVIDE_BY_ZERO"
|
||||
else:
|
||||
return variable/parameter
|
||||
else:
|
||||
return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER"
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Assign variables from canvas."
|
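The VariableAssigner above reads each {variable, operator, parameter} triple, pulls the current value from the canvas, applies the operator and writes the result back; type mismatches are reported as "ERROR:*" strings rather than raised. One detail worth flagging: the statement assert "Variable is not complete." is always true (a non-empty string literal), so the incomplete-item guard never fires; assert False, "Variable is not complete." or an explicit raise is presumably what was meant. A tiny standalone sketch of the operator dispatch on already-resolved values (no canvas indirection):

def operate(value, operator, parameter):
    # Numeric compound assignment plus a couple of list operators, mirroring
    # the component's behaviour once variable references are resolved.
    if operator == "+=":
        return value + parameter
    if operator == "-=":
        return value - parameter
    if operator == "append":
        return (value or []) + [parameter]
    if operator == "remove_first":
        return value[1:] if value else value
    if operator == "clear":
        return type(value)() if value is not None else None
    return value  # "overwrite"/"set" resolve through the canvas in the real component

print(operate(10, "+=", 5))           # 15
print(operate(["a"], "append", "b"))  # ['a', 'b']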
||||
File diff suppressed because one or more lines are too long
@ -83,10 +83,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
@ -527,10 +527,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
},
|
||||
"label": "ExeSQL",
|
||||
"name": "ExeSQL"
|
||||
|
||||
@ -21,9 +21,8 @@ from functools import partial
|
||||
from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from common.misc_utils import hash_str2int
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
from rag.prompts.generator import kb_prompt
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
|
||||
from timeit import default_timer as timer
|
||||
|
||||
|
||||
|
||||
@ -13,16 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import ast
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from strenum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from common.connection_utils import timeout
|
||||
from strenum import StrEnum
|
||||
|
||||
from agent.tools.base import ToolBase, ToolMeta, ToolParamBase
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
|
||||
|
||||
class Language(StrEnum):
|
||||
@ -62,7 +66,7 @@ class CodeExecParam(ToolParamBase):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
self.meta: ToolMeta = {
|
||||
"name": "execute_code",
|
||||
"description": """
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and returns a JSON string.
|
||||
@ -99,16 +103,12 @@ module.exports = { main };
|
||||
"enum": ["python", "javascript"],
|
||||
"required": True,
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "A piece of code in right format. There MUST be main function.",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
"script": {"type": "string", "description": "A piece of code in right format. There MUST be main function.", "required": True},
|
||||
},
|
||||
}
|
||||
super().__init__()
|
||||
self.lang = Language.PYTHON.value
|
||||
self.script = "def main(arg1: str, arg2: str) -> dict: return {\"result\": arg1 + arg2}"
|
||||
self.script = 'def main(arg1: str, arg2: str) -> dict: return {"result": arg1 + arg2}'
|
||||
self.arguments = {}
|
||||
self.outputs = {"result": {"value": "", "type": "string"}}
|
||||
|
||||
@ -119,17 +119,14 @@ module.exports = { main };
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
res = {}
|
||||
for k, v in self.arguments.items():
|
||||
res[k] = {
|
||||
"type": "line",
|
||||
"name": k
|
||||
}
|
||||
res[k] = {"type": "line", "name": k}
|
||||
return res
|
||||
|
||||
|
||||
class CodeExec(ToolBase, ABC):
|
||||
component_name = "CodeExec"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("CodeExec processing"):
|
||||
return
|
||||
@ -138,17 +135,12 @@ class CodeExec(ToolBase, ABC):
|
||||
script = kwargs.get("script", self._param.script)
|
||||
arguments = {}
|
||||
for k, v in self._param.arguments.items():
|
||||
|
||||
if kwargs.get(k):
|
||||
arguments[k] = kwargs[k]
|
||||
continue
|
||||
arguments[k] = self._canvas.get_variable_value(v) if v else None
|
||||
|
||||
self._execute_code(
|
||||
language=lang,
|
||||
code=script,
|
||||
arguments=arguments
|
||||
)
|
||||
self._execute_code(language=lang, code=script, arguments=arguments)
|
||||
|
||||
def _execute_code(self, language: str, code: str, arguments: dict):
|
||||
import requests
|
||||
@ -169,7 +161,7 @@ class CodeExec(ToolBase, ABC):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return "Task has been canceled"
|
||||
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)))
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
|
||||
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
@ -183,35 +175,10 @@ class CodeExec(ToolBase, ABC):
|
||||
if stderr:
|
||||
self.set_output("_ERROR", stderr)
|
||||
return
|
||||
try:
|
||||
rt = eval(body.get("stdout", ""))
|
||||
except Exception:
|
||||
rt = body.get("stdout", "")
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
|
||||
if isinstance(rt, tuple):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[i]
|
||||
elif isinstance(rt, dict):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k not in rt or k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[k]
|
||||
else:
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt
|
||||
raw_stdout = body.get("stdout", "")
|
||||
parsed_stdout = self._deserialize_stdout(raw_stdout)
|
||||
logging.info(f"[CodeExec]: http://{settings.SANDBOX_HOST}:9385/run -> {parsed_stdout}")
|
||||
self._populate_outputs(parsed_stdout, raw_stdout)
|
||||
else:
|
||||
self.set_output("_ERROR", "There is no response from sandbox")
|
||||
|
||||
@ -228,3 +195,149 @@ class CodeExec(ToolBase, ABC):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Running a short script to process data."
|
||||
|
||||
def _deserialize_stdout(self, stdout: str):
|
||||
text = str(stdout).strip()
|
||||
if not text:
|
||||
return ""
|
||||
for loader in (json.loads, ast.literal_eval):
|
||||
try:
|
||||
return loader(text)
|
||||
except Exception:
|
||||
continue
|
||||
return text
|
||||
|
||||
def _coerce_output_value(self, value, expected_type: Optional[str]):
|
||||
if expected_type is None:
|
||||
return value
|
||||
|
||||
etype = expected_type.strip().lower()
|
||||
inner_type = None
|
||||
if etype.startswith("array<") and etype.endswith(">"):
|
||||
inner_type = etype[6:-1].strip()
|
||||
etype = "array"
|
||||
|
||||
try:
|
||||
if etype == "string":
|
||||
return "" if value is None else str(value)
|
||||
|
||||
if etype == "number":
|
||||
if value is None or value == "":
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return value
|
||||
return float(value)
|
||||
|
||||
if etype == "boolean":
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
lv = value.lower()
|
||||
if lv in ("true", "1", "yes", "y", "on"):
|
||||
return True
|
||||
if lv in ("false", "0", "no", "n", "off"):
|
||||
return False
|
||||
return bool(value)
|
||||
|
||||
if etype == "array":
|
||||
candidate = value
|
||||
if isinstance(candidate, str):
|
||||
parsed = self._deserialize_stdout(candidate)
|
||||
candidate = parsed
|
||||
if isinstance(candidate, tuple):
|
||||
candidate = list(candidate)
|
||||
if not isinstance(candidate, list):
|
||||
candidate = [] if candidate is None else [candidate]
|
||||
|
||||
if inner_type == "string":
|
||||
return ["" if v is None else str(v) for v in candidate]
|
||||
if inner_type == "number":
|
||||
coerced = []
|
||||
for v in candidate:
|
||||
try:
|
||||
if v is None or v == "":
|
||||
coerced.append(None)
|
||||
elif isinstance(v, (int, float)):
|
||||
coerced.append(v)
|
||||
else:
|
||||
coerced.append(float(v))
|
||||
except Exception:
|
||||
coerced.append(v)
|
||||
return coerced
|
||||
return candidate
|
||||
|
||||
if etype == "object":
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
parsed = self._deserialize_stdout(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return value
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
def _populate_outputs(self, parsed_stdout, raw_stdout: str):
|
||||
outputs_items = list(self._param.outputs.items())
|
||||
logging.info(f"[CodeExec]: outputs schema keys: {[k for k, _ in outputs_items]}")
|
||||
if not outputs_items:
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, dict):
|
||||
for key, meta in outputs_items:
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = self._get_by_path(parsed_stdout, key)
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
if isinstance(parsed_stdout, (list, tuple)):
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = parsed_stdout[idx] if idx < len(parsed_stdout) else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate list key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
return
|
||||
|
||||
default_val = parsed_stdout if parsed_stdout is not None else raw_stdout
|
||||
for idx, (key, meta) in enumerate(outputs_items):
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
val = default_val if idx == 0 else None
|
||||
coerced = self._coerce_output_value(val, meta.get("type"))
|
||||
logging.info(f"[CodeExec]: populate scalar key='{key}' raw='{val}' coerced='{coerced}'")
|
||||
self.set_output(key, coerced)
|
||||
|
||||
def _get_by_path(self, data, path: str):
|
||||
if not path:
|
||||
return None
|
||||
cur = data
|
||||
for part in path.split("."):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
return None
|
||||
if isinstance(cur, dict):
|
||||
cur = cur.get(part)
|
||||
elif isinstance(cur, list):
|
||||
try:
|
||||
idx = int(part)
|
||||
cur = cur[idx]
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if cur is None:
|
||||
return None
|
||||
logging.info(f"[CodeExec]: resolve path '{path}' -> {cur}")
|
||||
return cur
|
||||
|
||||
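The replacement above funnels the sandbox's raw stdout through two helpers: _deserialize_stdout tries json.loads, then ast.literal_eval, then falls back to the raw string, and _coerce_output_value converts values to the declared output types such as "number", "boolean" or "array<number>". A compact sketch of the parsing fallback and one coercion case (standalone helpers, not the component API):

import ast
import json

def parse_stdout(text):
    # JSON first, then Python literals, then hand the string back untouched.
    text = str(text).strip()
    for loader in (json.loads, ast.literal_eval):
        try:
            return loader(text)
        except Exception:
            continue
    return text

print(parse_stdout('{"total": 3}'))  # {'total': 3}
print(parse_stdout("('a', 1)"))      # ('a', 1) -- a Python literal, not valid JSON
print(parse_stdout("plain text"))    # 'plain text'

def as_array_of_number(value):
    # "array<number>" handling: wrap scalars in a list and cast string entries to float.
    items = value if isinstance(value, list) else [value]
    return [float(v) if isinstance(v, str) else v for v in items]

print(as_array_of_number(["1", 2, 3.5]))  # [1.0, 2, 3.5]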
@ -132,12 +132,12 @@ class Retrieval(ToolBase, ABC):
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if self._param.meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
filters = gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif self._param.meta_data_filter.get("method") == "manual":
|
||||
filters=self._param.meta_data_filter["manual"]
|
||||
filters = self._param.meta_data_filter["manual"]
|
||||
for flt in filters:
|
||||
pat = re.compile(self.variable_ref_patt)
|
||||
s = flt["value"]
|
||||
@ -165,9 +165,9 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
out_parts.append(s[last:])
|
||||
flt["value"] = "".join(out_parts)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and")))
|
||||
if filters and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
if self._param.cross_languages:
|
||||
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
|
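In the retrieval hunk above, gen_meta_filter now returns a dict instead of a bare list, and meta_filter receives the conditions and the combining logic separately; when manual filters are configured but match nothing, doc_ids is set to ["-999"] so the search returns no chunks instead of silently ignoring the filter. A sketch of the assumed payload shape and the corresponding call (the condition fields shown here are illustrative; only the "conditions" and "logic" keys are visible in the diff):

filters = {
    "logic": "and",  # how the individual conditions are combined
    "conditions": [
        {"key": "department", "op": "=", "value": "finance"},
        {"key": "year", "op": ">=", "value": "2023"},
    ],
}
# doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))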
||||
@ -18,30 +18,32 @@ import sys
|
||||
import logging
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask
|
||||
from quart import Blueprint, Quart, request, g, current_app, session
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flask_cors import CORS
|
||||
from flasgger import Swagger
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
from quart_cors import cors
|
||||
from common.constants import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.db_models import close_connection, APIToken
|
||||
from api.db.services import UserService
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from quart_auth import Unauthorized
|
||||
from common import settings
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.constants import API_VERSION
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Flask(__name__)
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
smtp_mail_server = Mail()
|
||||
|
||||
# Add this at the beginning of your file to configure Swagger UI
|
||||
@ -76,7 +78,6 @@ swagger = Swagger(
|
||||
},
|
||||
)
|
||||
|
||||
CORS(app, supports_credentials=True, max_age=2592000)
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
@ -84,24 +85,154 @@ app.errorhandler(Exception)(server_error_response)
|
||||
## convenience for dev and debug
|
||||
# app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
app.config["SESSION_TYPE"] = "filesystem"
|
||||
app.config["SESSION_TYPE"] = "redis"
|
||||
app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis")
|
||||
app.config["MAX_CONTENT_LENGTH"] = int(
|
||||
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
|
||||
)
|
||||
|
||||
Session(app)
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
app.config['SECRET_KEY'] = settings.SECRET_KEY
|
||||
app.secret_key = settings.SECRET_KEY
|
||||
commands.register_commands(app)
|
||||
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
from collections.abc import Awaitable, Callable
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
def _load_user():
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = request.headers.get("Authorization")
|
||||
g.user = None
|
||||
if not authorization:
|
||||
return
|
||||
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if not user and len(authorization.split()) == 2:
|
||||
objs = APIToken.query(token=authorization.split()[1])
|
||||
if objs:
|
||||
user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
g.user = user[0]
|
||||
return user[0]
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
|
||||
|
||||
current_user = LocalProxy(_load_user)
|
||||
|
||||
|
||||
def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
"""A decorator to restrict route access to authenticated users.
|
||||
|
||||
This should be used to wrap a route handler (or view function) to
|
||||
enforce that only authenticated requests can access it. Note that
|
||||
it is important that this decorator be wrapped by the route
decorator and not vice versa, as below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@app.route('/')
|
||||
@login_required
|
||||
async def index():
|
||||
...
|
||||
|
||||
If the request is not authenticated a
|
||||
`quart.exceptions.Unauthorized` exception will be raised.
|
||||
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if not current_user:  # or not session.get("_user_id"):
|
||||
raise Unauthorized()
|
||||
else:
|
||||
return await current_app.ensure_async(func)(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
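Because the wrapper raises quart_auth's `Unauthorized`, a hedged sketch of converting that exception into a JSON 401 at the app level (the handler body is an assumption, not part of this change):

# Illustrative only: map the Unauthorized raised by login_required to a JSON 401 response.
@app.errorhandler(Unauthorized)
async def _handle_unauthorized(_err):
    return {"code": 401, "message": "Unauthorized", "data": False}, 401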
|
||||
|
||||
def login_user(user, remember=False, duration=None, force=False, fresh=True):
|
||||
"""
|
||||
Logs a user in. You should pass the actual user object to this. If the
|
||||
user's `is_active` property is ``False``, they will not be logged in
|
||||
unless `force` is ``True``.
|
||||
|
||||
This will return ``True`` if the log in attempt succeeds, and ``False`` if
|
||||
it fails (i.e. because the user is inactive).
|
||||
|
||||
:param user: The user object to log in.
|
||||
:type user: object
|
||||
:param remember: Whether to remember the user after their session expires.
|
||||
Defaults to ``False``.
|
||||
:type remember: bool
|
||||
:param duration: The amount of time before the remember cookie expires. If
|
||||
``None`` the value set in the settings is used. Defaults to ``None``.
|
||||
:type duration: :class:`datetime.timedelta`
|
||||
:param force: If the user is inactive, setting this to ``True`` will log
|
||||
them in regardless. Defaults to ``False``.
|
||||
:type force: bool
|
||||
:param fresh: setting this to ``False`` will log in the user with a session
|
||||
marked as not "fresh". Defaults to ``True``.
|
||||
:type fresh: bool
|
||||
"""
|
||||
if not force and not user.is_active:
|
||||
return False
|
||||
|
||||
session["_user_id"] = user.id
|
||||
session["_fresh"] = fresh
|
||||
session["_id"] = get_uuid()
|
||||
return True
|
||||
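A hedged sketch of how this session-based `login_user` might be called from an async login route (the route path, user lookup, and response shape are illustrative):

# Illustrative only: the lookup and response here are placeholders, not repository code.
@app.route("/v1/user/login_example", methods=["POST"])
async def _login_example():
    req = await request.get_json()
    users = UserService.query(email=req.get("email"), status=StatusEnum.VALID.value)
    if not users or not login_user(users[0]):
        return {"code": 401, "message": "Invalid credentials", "data": False}, 401
    return {"code": 0, "message": "", "data": {"id": users[0].id}}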
|
||||
|
||||
def logout_user():
|
||||
"""
|
||||
Logs a user out. (You do not need to pass the actual user.) This will
|
||||
also clean up the remember me cookie if it exists.
|
||||
"""
|
||||
if "_user_id" in session:
|
||||
session.pop("_user_id")
|
||||
|
||||
if "_fresh" in session:
|
||||
session.pop("_fresh")
|
||||
|
||||
if "_id" in session:
|
||||
session.pop("_id")
|
||||
|
||||
COOKIE_NAME = "remember_token"
|
||||
cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
|
||||
if cookie_name in request.cookies:
|
||||
session["_remember"] = "clear"
|
||||
if "_remember_seconds" in session:
|
||||
session.pop("_remember_seconds")
|
||||
|
||||
return True
|
||||
|
||||
def search_pages_path(page_path):
|
||||
app_path_list = [
|
||||
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
|
||||
]
|
||||
api_path_list = [
|
||||
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(api_path_list)
|
||||
return app_path_list
|
||||
@ -138,44 +269,12 @@ pages_dir = [
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
|
||||
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
|
||||
]
|
||||
|
||||
|
||||
@login_manager.request_loader
|
||||
def load_user(web_request):
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = web_request.headers.get("Authorization")
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return user[0]
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exc):
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
logging.exception(f"Request failed: {exception}")
|
||||
close_connection()
|
||||
|
||||
@ -13,46 +13,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from flask import request, Response
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import APIToken, Task, File
|
||||
from api.db.services import duplicate_name
|
||||
from quart import request
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import queue_tasks, TaskService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
generate_confirmation_token
|
||||
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import keyword_extraction
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from agent.canvas import Canvas
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def new_token():
|
||||
req = request.json
|
||||
async def new_token():
|
||||
req = await request.json
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
@ -97,8 +72,8 @@ def token_list():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("tokens", "tenant_id")
|
||||
@login_required
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request.json
|
||||
try:
|
||||
for token in req["tokens"]:
|
||||
APITokenService.filter_delete(
|
||||
@ -126,770 +101,19 @@ def stats():
|
||||
"to_date",
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
|
||||
"agent" if "canvas_id" in request.args else None)
|
||||
res = {
|
||||
"pv": [(o["dt"], o["pv"]) for o in objs],
|
||||
"uv": [(o["dt"], o["uv"]) for o in objs],
|
||||
"speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
|
||||
"tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
|
||||
"round": [(o["dt"], o["round"]) for o in objs],
|
||||
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
|
||||
}
|
||||
|
||||
res = {"pv": [], "uv": [], "speed": [], "tokens": [], "round": [], "thumb_up": []}
|
||||
|
||||
for obj in objs:
|
||||
dt = obj["dt"]
|
||||
res["pv"].append((dt, obj["pv"]))
|
||||
res["uv"].append((dt, obj["uv"]))
|
||||
res["speed"].append((dt, float(obj["tokens"]) / (float(obj["duration"]) + 0.1))) # +0.1 to avoid division by zero
|
||||
res["tokens"].append((dt, float(obj["tokens"]) / 1000.0)) # convert to thousands
|
||||
res["round"].append((dt, obj["round"]))
|
||||
res["thumb_up"].append((dt, obj["thumb_up"]))
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
|
||||
def set_conversation():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
if objs[0].source == "agent":
|
||||
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent"
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
else:
|
||||
e, dia = DialogService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": dia.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = False
|
||||
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
def rename_field(ans):
|
||||
reference = ans['reference']
|
||||
if not isinstance(reference, dict):
|
||||
return
|
||||
for chunk_i in reference.get('chunks', []):
|
||||
if 'docnm_kwd' in chunk_i:
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
stream = req.get("stream", True)
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "content": ""}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=stream)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
if stream:
|
||||
assert isinstance(answer, partial), "Nothing. Is it over?"
|
||||
|
||||
def sse():
|
||||
nonlocal answer, cvs, conv
|
||||
try:
|
||||
for ans in answer():
|
||||
for k in ans.keys():
|
||||
final_ans[k] = ans[k]
|
||||
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
canvas.history.append(("assistant", final_ans["content"]))
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
fillin_conv(result)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
rename_field(result)
|
||||
return get_json_result(data=result)
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **req):
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
answer = ans
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
rename_field(answer)
|
||||
return get_json_result(data=answer)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
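Both branches above stream `data: {...}\n\n` frames and finish with a frame whose `data` is `true`; a hedged client-side sketch of consuming that stream (base URL, route prefix, and API key are placeholders):

# Illustrative SSE consumer for the completion endpoint; URL and token are placeholders.
import json
import requests

resp = requests.post(
    "http://localhost:9380/v1/api/completion",
    headers={"Authorization": "Bearer <API_KEY>"},
    json={"conversation_id": "<conv_id>",
          "messages": [{"role": "user", "content": "hi"}],
          "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    payload = json.loads(line[len("data:"):])
    if payload.get("data") is True:  # terminator frame
        break
    print(payload["data"].get("answer", ""))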
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_conversation(conversation_id):
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
e, conv = API4ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
||||
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for referenct_i in conv['reference']:
|
||||
if referenct_i is None or len(referenct_i) == 0:
|
||||
continue
|
||||
for chunk_i in referenct_i['chunks']:
|
||||
if 'docnm_kwd' in chunk_i.keys():
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/upload', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_name")
|
||||
def upload():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
kb_name = request.form.get("kb_name").strip()
|
||||
tenant_id = objs[0].tenant_id
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
kb_root_folder = FileService.get_kb_folder(tenant_id)
|
||||
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
try:
|
||||
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
|
||||
return get_data_error_result(
|
||||
message="Exceed the maximum file number of a free user!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
name=file.filename,
|
||||
kb_id=kb_id)
|
||||
filetype = filename_type(filename)
|
||||
if not filetype:
|
||||
return get_data_error_result(
|
||||
message="This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
|
||||
location += "_"
|
||||
blob = request.files['file'].read()
|
||||
settings.STORAGE_IMPL.put(kb_id, location, blob)
|
||||
doc = {
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": kb.tenant_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
|
||||
form_data = request.form
|
||||
if "parser_id" in form_data.keys():
|
||||
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
|
||||
doc["parser_id"] = request.form.get("parser_id").strip()
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
doc["parser_id"] = ParserType.EMAIL.value
|
||||
|
||||
doc_result = DocumentService.insert(doc)
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if "run" in form_data.keys():
|
||||
if request.form.get("run").strip() == "1":
|
||||
try:
|
||||
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
|
||||
DocumentService.update_by_id(doc["id"], info)
|
||||
# if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
tenant_id = DocumentService.get_tenant_id(doc["id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
# e, doc = DocumentService.get_by_id(doc["id"])
|
||||
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
||||
e, doc = DocumentService.get_by_id(doc["id"])
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=doc_result.to_json())
|
||||
|
||||
|
||||
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id")
|
||||
def upload_parse():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist('file')
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_chunks():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
|
||||
try:
|
||||
if "doc_name" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
|
||||
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
|
||||
|
||||
elif "doc_id" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
|
||||
doc_id = req['doc_id']
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_name or doc_id"
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
"doc_name": res_item["docnm_kwd"],
|
||||
"image_id": res_item["img_id"]
|
||||
} for res_item in res
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=res)
|
||||
|
||||
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_chunk(chunk_id):
|
||||
from rag.nlp import search
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
if chunk is None:
|
||||
return server_error_response(Exception("Chunk not found"))
|
||||
k = []
|
||||
for n in chunk.keys():
|
||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||
k.append(n)
|
||||
for n in k:
|
||||
del chunk[n]
|
||||
|
||||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_kb_docs():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_name = req.get("kb_name", "").strip()
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
page_number = int(req.get("page", 1))
|
||||
items_per_page = int(req.get("page_size", 15))
|
||||
orderby = req.get("orderby", "create_time")
|
||||
desc = req.get("desc", True)
|
||||
keywords = req.get("keywords", "")
|
||||
status = req.get("status", [])
|
||||
if status:
|
||||
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
|
||||
)
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
|
||||
)
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
|
||||
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/infos', methods=['POST']) # noqa: F821
|
||||
@validate_request("doc_ids")
|
||||
def docinfos():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
doc_ids = req["doc_ids"]
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
return get_json_result(data=list(docs.dicts()))
|
||||
|
||||
|
||||
@manager.route('/document', methods=['DELETE']) # noqa: F821
|
||||
# @login_required
|
||||
def document_rm():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
try:
|
||||
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
|
||||
for doc_id in req.get("doc_ids", []):
|
||||
if doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
|
||||
if not doc_ids:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_names or doc_ids"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
|
||||
errors = ""
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
doc_dic = {}
|
||||
for doc in docs:
|
||||
doc_dic[doc.id] = doc
|
||||
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
if doc_id not in doc_dic:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
doc = doc_dic[doc_id]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
|
||||
settings.STORAGE_IMPL.rm(b, n)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
|
||||
@validate_request("Authorization", "conversation_id", "word")
|
||||
def completion_faq():
|
||||
import base64
|
||||
req = request.json
|
||||
|
||||
token = req["Authorization"]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = True
|
||||
|
||||
msg = [{"role": "user", "content": req["word"]}]
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "doc_aggs": []}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=False)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
ans = ""
|
||||
for a in chat(dia, msg, stream=False, **req):
|
||||
ans = a
|
||||
break
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
kb_ids = req.get("kb_id", [])
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
question = req.get("question")
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("page_size", 30))
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
highlight = bool(req.get("highlight", False))
|
||||
|
||||
try:
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embd_nms) != 1:
|
||||
return get_json_result(
|
||||
data=False, message='Knowledge bases use different embedding models or does not exist."',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
|
||||
rank_feature=label_question(question, kbs))
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
code=RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
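A hedged example of calling the retrieval route above over HTTP (base URL, API key, and IDs are placeholders; the field names follow the request parsing shown):

# Illustrative request against the /retrieval route; all values are placeholders.
import requests

r = requests.post(
    "http://localhost:9380/v1/api/retrieval",
    headers={"Authorization": "Bearer <API_KEY>"},
    json={
        "kb_id": ["<kb_id>"],
        "question": "What is RAGFlow?",
        "page": 1,
        "page_size": 10,
        "similarity_threshold": 0.2,
        "vector_similarity_weight": 0.3,
    },
)
payload = r.json()
for chunk in (payload.get("data") or {}).get("chunks", []):
    print(chunk.get("docnm_kwd"), str(chunk.get("content_with_weight", ""))[:80])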
@ -18,12 +18,8 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from quart import request, Response, make_response
|
||||
from agent.component import LLM
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
@ -35,7 +31,8 @@ from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
|
||||
request_json
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
@ -46,6 +43,7 @@ from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@ -57,8 +55,9 @@ def templates():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
for i in req["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -70,8 +69,8 @@ def rm():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
async def save():
|
||||
req = await request_json()
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
@ -129,8 +128,8 @@ def getsse(canvas_id):
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
async def run():
|
||||
req = await request_json()
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
@ -160,10 +159,10 @@ def run():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
def sse():
|
||||
async def sse():
|
||||
nonlocal canvas, user_id
|
||||
try:
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
@ -179,15 +178,15 @@ def run():
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
resp.call_on_close(lambda: canvas.cancel_task())
|
||||
#resp.call_on_close(lambda: canvas.cancel_task())
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
async def rerun():
|
||||
req = await request_json()
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
@ -224,8 +223,8 @@ def cancel(task_id):
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
async def reset():
|
||||
req = await request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -245,7 +244,7 @@ def reset():
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
async def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
@ -311,7 +310,8 @@ def upload(canvas_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files['file']
|
||||
files = await request.files
|
||||
file = files['file']
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
@ -342,8 +342,8 @@ def input_form():
|
||||
@manager.route('/debug', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
async def debug():
|
||||
req = await request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -374,8 +374,8 @@ def debug():
|
||||
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
def test_db_connect():
|
||||
req = request.json
|
||||
async def test_db_connect():
|
||||
req = await request_json()
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
@ -426,7 +426,6 @@ def test_db_connect():
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception as e:
|
||||
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
|
||||
|
||||
@ -438,7 +437,7 @@ def test_db_connect():
|
||||
|
||||
auth = None
|
||||
if http_scheme == "https" and req.get("password"):
|
||||
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
@ -471,8 +470,8 @@ def test_db_connect():
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=list)
|
||||
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=versions)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
@ -520,8 +519,8 @@ def list_canvas():
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
async def setting():
|
||||
req = await request_json()
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@ -602,8 +601,8 @@ def prompts():
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
async def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
return await make_response(blob)
|
||||
|
||||
@ -18,8 +18,7 @@ import json
|
||||
import re
|
||||
|
||||
import xxhash
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -27,7 +26,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -35,13 +35,14 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def list_chunk():
|
||||
req = request.json
|
||||
async def list_chunk():
|
||||
req = await request_json()
|
||||
doc_id = req["doc_id"]
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
@ -121,8 +122,8 @@ def get():
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
def set():
|
||||
req = request.json
|
||||
async def set():
|
||||
req = await request_json()
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -178,8 +179,8 @@ def set():
|
||||
@manager.route('/switch', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||
def switch():
|
||||
req = request.json
|
||||
async def switch():
|
||||
req = await request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
@ -198,8 +199,8 @@ def switch():
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
@ -222,8 +223,8 @@ def rm():
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
def create():
|
||||
req = request.json
|
||||
async def create():
|
||||
req = await request_json()
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -280,8 +281,8 @@ def create():
|
||||
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test():
|
||||
req = request.json
|
||||
async def retrieval_test():
|
||||
req = await request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -304,14 +305,14 @@ def retrieval_test():
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
|
||||
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import time

@@ -20,8 +21,7 @@ import uuid
from html import escape
from typing import Any

from flask import make_response, request
from flask_login import current_user, login_required
from quart import request, make_response
from google_auth_oauthlib.flow import Flow

from api.db import InputType

@@ -32,12 +32,13 @@ from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, Docum
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
from common.misc_utils import get_uuid
from rag.utils.redis_conn import REDIS_CONN
from api.apps import login_required, current_user


@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_connector():
req = request.json
async def set_connector():
req = await request.json
if req.get("id"):
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
ConnectorService.update_by_id(req["id"], conn)

@@ -55,10 +56,9 @@ def set_connector():
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
"status": TaskStatus.SCHEDULE,
}
conn["status"] = TaskStatus.SCHEDULE
ConnectorService.save(**conn)

time.sleep(1)
await asyncio.sleep(1)
e, conn = ConnectorService.get_by_id(req["id"])

return get_json_result(data=conn.to_dict())

@@ -89,8 +89,8 @@ def list_logs(connector_id):

@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
@login_required
def resume(connector_id):
req = request.json
async def resume(connector_id):
req = await request.json
if req.get("resume"):
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
else:

@@ -101,8 +101,8 @@ def resume(connector_id):
@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("kb_id")
def rebuild(connector_id):
req = request.json
async def rebuild(connector_id):
req = await request.json
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
if err:
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)

@@ -146,7 +146,7 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
return {"web": web_section}


def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
status = "success" if success else "error"
auto_close = "window.close();" if success else ""
escaped_message = escape(message)

@@ -164,7 +164,7 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
payload_json=payload_json,
auto_close=auto_close,
)
response = make_response(html, 200)
response = await make_response(html, 200)
response.headers["Content-Type"] = "text/html; charset=utf-8"
return response

@@ -172,14 +172,14 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
@login_required
@validate_request("credentials")
def start_google_drive_web_oauth():
async def start_google_drive_web_oauth():
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
return get_json_result(
code=RetCode.SERVER_ERROR,
message="Google Drive OAuth redirect URI is not configured on the server.",
)

req = request.json or {}
req = await request.json or {}
raw_credentials = req.get("credentials", "")
try:
credentials = _load_credentials(raw_credentials)

@@ -231,31 +231,31 @@ def start_google_drive_web_oauth():

@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
def google_drive_web_oauth_callback():
async def google_drive_web_oauth_callback():
state_id = request.args.get("state")
error = request.args.get("error")
error_description = request.args.get("error_description") or error

if not state_id:
return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.")

state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
if not state_cache:
return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")

state_obj = json.loads(state_cache)
client_config = state_obj.get("client_config")
if not client_config:
REDIS_CONN.delete(_web_state_cache_key(state_id))
return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")

if error:
REDIS_CONN.delete(_web_state_cache_key(state_id))
return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")

code = request.args.get("code")
if not code:
return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")

try:
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])

@@ -264,7 +264,7 @@ def google_drive_web_oauth_callback():
except Exception as exc: # pragma: no cover - defensive
logging.exception("Failed to exchange Google OAuth code: %s", exc)
REDIS_CONN.delete(_web_state_cache_key(state_id))
return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")

creds_json = flow.credentials.to_json()
result_payload = {

@@ -274,14 +274,14 @@ def google_drive_web_oauth_callback():
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
REDIS_CONN.delete(_web_state_cache_key(state_id))

return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")


@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
@login_required
@validate_request("flow_id")
def poll_google_drive_web_result():
req = request.json or {}
async def poll_google_drive_web_result():
req = await request.json or {}
flow_id = req.get("flow_id")
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
if not cache_raw:
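
Sketch only, under the same Quart assumption: inside an async view a blocking sleep would stall the event loop, which is why the hunk above replaces time.sleep with asyncio.sleep and awaits make_response before touching headers.

import asyncio
from quart import make_response

async def example_popup(html: str):
    await asyncio.sleep(1)                  # yield control instead of blocking the loop
    response = await make_response(html, 200)
    response.headers["Content-Type"] = "text/html; charset=utf-8"
    return response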
@@ -17,8 +17,8 @@ import json
import re
import logging
from copy import deepcopy
from flask import Response, request
from flask_login import current_user, login_required
from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap

@@ -34,8 +34,8 @@ from common.constants import RetCode, LLMType

@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_conversation():
req = request.json
async def set_conversation():
req = await request.json
conv_id = req.get("conversation_id")
is_new = req.get("is_new")
name = req.get("name", "New conversation")

@@ -85,7 +85,6 @@ def get():
if not e:
return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
avatar = None
for tenant in tenants:
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
if dialog and len(dialog) > 0:

@@ -129,8 +128,9 @@ def getsse(dialog_id):

@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
def rm():
conv_ids = request.json["conversation_ids"]
async def rm():
req = await request.json
conv_ids = req["conversation_ids"]
try:
for cid in conv_ids:
exist, conv = ConversationService.get_by_id(cid)

@@ -166,8 +166,8 @@ def list_conversation():
@manager.route("/completion", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
req = request.json
async def completion():
req = await request.json
msg = []
for m in req["messages"]:
if m["role"] == "system":

@@ -251,8 +251,8 @@ def completion():

@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required
def tts():
req = request.json
async def tts():
req = await request.json
text = req["text"]

tenants = TenantService.get_info_by(current_user.id)

@@ -284,8 +284,8 @@ def tts():
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def delete_msg():
req = request.json
async def delete_msg():
req = await request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")

@@ -307,8 +307,8 @@ def delete_msg():
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def thumbup():
req = request.json
async def thumbup():
req = await request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")

@@ -334,8 +334,8 @@ def thumbup():
@manager.route("/ask", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def ask_about():
req = request.json
async def ask_about():
req = await request.json
uid = current_user.id

search_id = req.get("search_id", "")

@@ -366,8 +366,8 @@ def ask_about():
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def mindmap():
req = request.json
async def mindmap():
req = await request.json
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {}

@@ -384,8 +384,8 @@ def mindmap():
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
def related_questions():
req = request.json
async def related_questions():
req = await request.json

search_id = req.get("search_id", "")
search_config = {}
@@ -14,8 +14,7 @@
# limitations under the License.
#

from flask import request
from flask_login import login_required, current_user
from quart import request
from api.db.services import duplicate_name
from api.db.services.dialog_service import DialogService
from common.constants import StatusEnum

@@ -26,13 +25,14 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.utils.api_utils import get_json_result
from api.apps import login_required, current_user


@manager.route('/set', methods=['POST']) # noqa: F821
@validate_request("prompt_config")
@login_required
def set_dialog():
req = request.json
async def set_dialog():
req = await request.json
dialog_id = req.get("dialog_id", "")
is_create = not dialog_id
name = req.get("name", "New Dialog")

@@ -154,33 +154,34 @@ def get_kb_names(kb_ids):
@login_required
def list_dialogs():
try:
diags = DialogService.query(
conversations = DialogService.query(
tenant_id=current_user.id,
status=StatusEnum.VALID.value,
reverse=True,
order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
return get_json_result(data=diags)
conversations = [d.to_dict() for d in conversations]
for conversation in conversations:
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
return get_json_result(data=conversations)
except Exception as e:
return server_error_response(e)


@manager.route('/next', methods=['POST']) # noqa: F821
@login_required
def list_dialogs_next():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
parser_id = request.args.get("parser_id")
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
async def list_dialogs_next():
args = request.args
keywords = args.get("keywords", "")
page_number = int(args.get("page", 0))
items_per_page = int(args.get("page_size", 0))
parser_id = args.get("parser_id")
orderby = args.get("orderby", "create_time")
if args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True

req = request.get_json()
req = await request.get_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:

@@ -207,8 +208,8 @@ def list_dialogs_next():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("dialog_ids")
def rm():
req = request.json
async def rm():
req = await request.json
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try:
@@ -18,11 +18,8 @@ import os.path
import pathlib
import re
from pathlib import Path

import flask
from flask import request
from flask_login import current_user, login_required

from quart import request, make_response
from api.apps import current_user, login_required
from api.common.check_team_permission import check_kb_team_permission
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.db import VALID_FILE_TYPES, FileType

@@ -39,7 +36,7 @@ from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
validate_request, request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory

@@ -53,14 +50,16 @@ from common import settings
@manager.route("/upload", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id")
def upload():
kb_id = request.form.get("kb_id")
async def upload():
form = await request.form
kb_id = form.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
if "file" not in request.files:
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)

file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)

@@ -87,12 +86,13 @@ def upload():
@manager.route("/web_crawl", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id", "name", "url")
def web_crawl():
kb_id = request.form.get("kb_id")
async def web_crawl():
form = await request.form
kb_id = form.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
name = request.form.get("name")
url = request.form.get("url")
name = form.get("name")
url = form.get("url")
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)

@@ -152,8 +152,8 @@ def web_crawl():
@manager.route("/create", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "kb_id")
def create():
req = request.json
async def create():
req = await request_json()
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)

@@ -208,7 +208,7 @@ def create():

@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_docs():
async def list_docs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)

@@ -230,7 +230,7 @@ def list_docs():
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))

req = request.get_json()
req = await request.get_json()

run_status = req.get("run_status", [])
if run_status:

@@ -270,8 +270,8 @@ def list_docs():

@manager.route("/filter", methods=["POST"]) # noqa: F821
@login_required
def get_filter():
req = request.get_json()
async def get_filter():
req = await request.get_json()

kb_id = req.get("kb_id")
if not kb_id:

@@ -308,8 +308,8 @@ def get_filter():

@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
def docinfos():
req = request.json
async def doc_infos():
req = await request_json()
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):

@@ -340,8 +340,8 @@ def thumbnails():
@manager.route("/change_status", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids", "status")
def change_status():
req = request.get_json()
async def change_status():
req = await request.get_json()
doc_ids = req.get("doc_ids", [])
status = str(req.get("status", ""))

@@ -380,8 +380,8 @@ def change_status():
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id")
def rm():
req = request.json
async def rm():
req = await request_json()
doc_ids = req["doc_id"]
if isinstance(doc_ids, str):
doc_ids = [doc_ids]

@@ -401,8 +401,8 @@ def rm():
@manager.route("/run", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_ids", "run")
def run():
req = request.json
async def run():
req = await request_json()
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

@@ -448,8 +448,8 @@ def run():
@manager.route("/rename", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "name")
def rename():
req = request.json
async def rename():
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:

@@ -495,19 +495,20 @@ def rename():

@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
# @login_required
def get(doc_id):
async def get(doc_id):
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(message="Document not found!")

b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))
response = await make_response(settings.STORAGE_IMPL.get(b, n))

ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
if ext:
if doc.type == FileType.VISUAL.value:

content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
else:
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")

@@ -517,12 +518,28 @@ def get(doc_id):
return server_error_response(e)


@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
@login_required
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))

return response

except Exception as e:
return server_error_response(e)


@manager.route("/change_parser", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id")
def change_parser():
async def change_parser():

req = request.json
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

@@ -544,6 +561,7 @@ def change_parser():
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None

try:
if "pipeline_id" in req and req["pipeline_id"] != "":

@@ -572,13 +590,13 @@ def change_parser():

@manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
# @login_required
def get_image(image_id):
async def get_image(image_id):
try:
arr = image_id.split("-")
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
response.headers.set("Content-Type", "image/JPEG")
return response
except Exception as e:

@@ -588,24 +606,25 @@ def get_image(image_id):

@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id")
def upload_and_parse():
if "file" not in request.files:
async def upload_and_parse():
files = await request.file
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)

file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)

doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)

form = await request.form
doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
return get_json_result(data=doc_ids)


@manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required
def parse():
url = request.json.get("url") if request.json else ""
async def parse():
url = await request.json.get("url") if await request.json else ""
if url:
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)

@@ -646,10 +665,11 @@ def parse():
txt = FileService.parse_docs([f], current_user.id)
return get_json_result(data=txt)

if "file" not in request.files:
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)

file_objs = request.files.getlist("file")
file_objs = files.getlist("file")
txt = FileService.parse_docs(file_objs, current_user.id)

return get_json_result(data=txt)

@@ -658,8 +678,8 @@ def parse():
@manager.route("/set_meta", methods=["POST"]) # noqa: F821
@login_required
@validate_request("doc_id", "meta")
def set_meta():
req = request.json
async def set_meta():
req = await request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
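
A hedged sketch of the upload pattern changed in the hunks above: under Quart, form fields and uploaded files are awaited rather than read as plain attributes, after which the multidict API is the same as before.

from quart import request

async def example_upload():
    form = await request.form           # replaces request.form
    files = await request.files         # replaces request.files
    kb_id = form.get("kb_id")
    file_objs = files.getlist("file")
    return kb_id, [f.filename for f in file_objs]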
@@ -19,8 +19,8 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService

from flask import request
from flask_login import login_required, current_user
from quart import request
from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from common.misc_utils import get_uuid

@@ -33,8 +33,8 @@ from api.utils.api_utils import get_json_result
@manager.route('/convert', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
def convert():
req = request.json
async def convert():
req = await request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []

@@ -103,8 +103,8 @@ def convert():
@manager.route('/rm', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
async def rm():
req = await request.json
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(
@@ -17,10 +17,8 @@ import logging
import os
import pathlib
import re

import flask
from flask import request
from flask_login import login_required, current_user
from quart import request, make_response
from api.apps import login_required, current_user

from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService

@@ -40,17 +38,19 @@ from common import settings
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
def upload():
pf_id = request.form.get("parent_id")
async def upload():
form = await request.form
pf_id = form.get("parent_id")

if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]

if 'file' not in request.files:
files = await request.files
if 'file' not in files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
file_objs = files.getlist('file')

for file_obj in file_objs:
if file_obj.filename == '':

@@ -123,10 +123,10 @@ def upload():
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.json
pf_id = request.json.get("parent_id")
input_file_type = request.json.get("type")
async def create():
req = await request.json
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]

@@ -238,16 +238,16 @@ def get_all_parent_folders():
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
async def rm():
req = await request.json
file_ids = req["file_ids"]

def _delete_single_file(file):
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")

informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:

@@ -299,8 +299,8 @@ def rm():
@manager.route('/rename', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_id", "name")
def rename():
req = request.json
async def rename():
req = await request.json
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:

@@ -338,7 +338,7 @@ def rename():

@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
@login_required
def get(file_id):
async def get(file_id):
try:
e, file = FileService.get_by_id(file_id)
if not e:

@@ -351,7 +351,7 @@ def get(file_id):
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = settings.STORAGE_IMPL.get(b, n)

response = flask.make_response(blob)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
ext = ext.group(1) if ext else None
if ext:

@@ -368,8 +368,8 @@ def get(file_id):
@manager.route("/mv", methods=["POST"]) # noqa: F821
@login_required
@validate_request("src_file_ids", "dest_file_id")
def move():
req = request.json
async def move():
req = await request.json
try:
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
import numpy as np
|
||||
|
||||
|
||||
from api.db.services.connector_service import Connector2KbService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
@ -30,7 +29,8 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
|
||||
request_json
|
||||
from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
@ -41,23 +41,28 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
async def create():
|
||||
req = await request_json()
|
||||
e, res = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = current_user.id,
|
||||
parser_id = req.pop("parser_id", None),
|
||||
**req
|
||||
)
|
||||
|
||||
if not e:
|
||||
return res
|
||||
|
||||
try:
|
||||
if not KnowledgebaseService.save(**req):
|
||||
if not KnowledgebaseService.save(**res):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id":req["id"]})
|
||||
return get_json_result(data={"kb_id":res["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -66,8 +71,8 @@ def create():
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "description", "parser_id")
|
||||
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.json
|
||||
async def update():
|
||||
req = await request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -165,18 +170,19 @@ def detail():
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_kbs():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
async def list_kbs():
|
||||
args = request.args
|
||||
keywords = args.get("keywords", "")
|
||||
page_number = int(args.get("page", 0))
|
||||
items_per_page = int(args.get("page_size", 0))
|
||||
parser_id = args.get("parser_id")
|
||||
orderby = args.get("orderby", "create_time")
|
||||
if args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -198,11 +204,12 @@ def list_kbs():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -278,8 +285,8 @@ def list_tags_from_kbs():
|
||||
|
||||
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rm_tags(kb_id):
|
||||
req = request.json
|
||||
async def rm_tags(kb_id):
|
||||
req = await request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -298,8 +305,8 @@ def rm_tags(kb_id):
|
||||
|
||||
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def rename_tags(kb_id):
|
||||
req = request.json
|
||||
async def rename_tags(kb_id):
|
||||
req = await request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -402,7 +409,7 @@ def get_basic_info():
|
||||
|
||||
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_logs():
|
||||
async def list_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -421,7 +428,7 @@ def list_pipeline_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -446,7 +453,7 @@ def list_pipeline_logs():
|
||||
|
||||
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_pipeline_dataset_logs():
|
||||
async def list_pipeline_dataset_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -463,7 +470,7 @@ def list_pipeline_dataset_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -480,12 +487,12 @@ def list_pipeline_dataset_logs():
|
||||
|
||||
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_pipeline_logs():
|
||||
async def delete_pipeline_logs():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
req = request.get_json()
|
||||
req = await request_json()
|
||||
log_ids = req.get("log_ids", [])
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
@ -509,8 +516,8 @@ def pipeline_log_detail():
|
||||
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_graphrag():
|
||||
req = request.json
|
||||
async def run_graphrag():
|
||||
req = await request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -578,8 +585,8 @@ def trace_graphrag():
|
||||
|
||||
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_raptor():
|
||||
req = request.json
|
||||
async def run_raptor():
|
||||
req = await request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -647,8 +654,8 @@ def trace_raptor():
|
||||
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_mindmap():
|
||||
req = request.json
|
||||
async def run_mindmap():
|
||||
req = await request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -731,6 +738,8 @@ def delete_kb_task():
|
||||
def cancel_task(task_id):
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
|
||||
kb_task_id_field: str = ""
|
||||
kb_task_finish_at: str = ""
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
kb_task_id_field = "graphrag_task_id"
|
||||
@ -761,7 +770,7 @@ def delete_kb_task():
|
||||
|
||||
@manager.route("/check_embedding", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
def check_embedding():
|
||||
async def check_embedding():
|
||||
|
||||
def _guess_vec_field(src: dict) -> str | None:
|
||||
for k in src or {}:
|
||||
@ -807,12 +816,12 @@ def check_embedding():
|
||||
offset=0, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
total = docStoreConn.getTotal(res0)
|
||||
total = docStoreConn.get_total(res0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
n = min(n, total)
|
||||
offsets = sorted(random.sample(range(total), n))
|
||||
offsets = sorted(random.sample(range(min(total,1000)), n))
|
||||
out = []
|
||||
|
||||
for off in offsets:
|
||||
@ -824,7 +833,7 @@ def check_embedding():
|
||||
offset=off, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
ids = docStoreConn.getChunkIds(res1)
|
||||
ids = docStoreConn.get_chunk_ids(res1)
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
@ -845,9 +854,14 @@ def check_embedding():
|
||||
"position_int": full_doc.get("position_int"),
|
||||
"top_int": full_doc.get("top_int"),
|
||||
"content_with_weight": full_doc.get("content_with_weight") or "",
|
||||
"question_kwd": full_doc.get("question_kwd") or []
|
||||
})
|
||||
return out
|
||||
req = request.json
|
||||
|
||||
def _clean(s: str) -> str:
|
||||
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
|
||||
return s if s else "None"
|
||||
req = await request_json()
|
||||
kb_id = req.get("kb_id", "")
|
||||
embd_id = req.get("embd_id", "")
|
||||
n = int(req.get("check_num", 5))
|
||||
@ -859,8 +873,10 @@ def check_embedding():
|
||||
|
||||
results, eff_sims = [], []
|
||||
for ck in samples:
|
||||
txt = (ck.get("content_with_weight") or "").strip()
|
||||
if not txt:
|
||||
title = ck.get("doc_name") or "Title"
|
||||
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
|
||||
txt_in = _clean(txt_in)
|
||||
if not txt_in:
|
||||
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
|
||||
continue
|
||||
|
||||
@ -869,10 +885,19 @@ def check_embedding():
|
||||
continue
|
||||
|
||||
try:
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
sim = _cos_sim(qv, ck["vector"])
|
||||
except Exception:
|
||||
return get_error_data_result(message="embedding failure")
|
||||
v, _ = emb_mdl.encode([title, txt_in])
|
||||
assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})"
|
||||
sim_content = _cos_sim(v[1], ck["vector"])
|
||||
title_w = 0.1
|
||||
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
|
||||
sim_mix = _cos_sim(qv_mix, ck["vector"])
|
||||
sim = sim_content
|
||||
mode = "content_only"
|
||||
if sim_mix > sim:
|
||||
sim = sim_mix
|
||||
mode = "title+content"
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Embedding failure. {e}")
|
||||
|
||||
eff_sims.append(sim)
|
||||
results.append({
|
||||
@ -892,9 +917,10 @@ def check_embedding():
|
||||
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"match_mode": mode,
|
||||
}
|
||||
if summary["avg_cos_sim"] > 0.99:
|
||||
if summary["avg_cos_sim"] > 0.9:
|
||||
return get_json_result(data={"summary": summary, "results": results})
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
|
||||
|
||||
|
||||
|
||||
@ -15,8 +15,8 @@
|
||||
#
|
||||
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from quart import request
|
||||
from api.apps import current_user, login_required
|
||||
from langfuse import Langfuse
|
||||
|
||||
from api.db.db_models import DB
|
||||
@ -27,8 +27,8 @@ from api.utils.api_utils import get_error_data_result, get_json_result, server_e
|
||||
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("secret_key", "public_key", "host")
|
||||
def set_api_key():
|
||||
req = request.get_json()
|
||||
async def set_api_key():
|
||||
req = await request.get_json()
|
||||
secret_key = req.get("secret_key", "")
|
||||
public_key = req.get("public_key", "")
|
||||
host = req.get("host", "")
|
||||
|
||||
@ -16,8 +16,9 @@
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from quart import request
|
||||
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
@ -52,8 +53,8 @@ def factories():
|
||||
@manager.route("/set_api_key", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
async def set_api_key():
|
||||
req = await request.json
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
@ -122,8 +123,8 @@ def set_api_key():
|
||||
@manager.route("/add_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def add_llm():
|
||||
req = request.json
|
||||
async def add_llm():
|
||||
req = await request.json
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
@ -142,11 +143,11 @@ def add_llm():
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
return set_api_key()
|
||||
return await set_api_key()
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
@ -267,8 +268,8 @@ def add_llm():
|
||||
@manager.route("/delete_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def delete_llm():
|
||||
req = request.json
|
||||
async def delete_llm():
|
||||
req = await request.json
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -276,8 +277,8 @@ def delete_llm():
|
||||
@manager.route("/enable_llm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def enable_llm():
|
||||
req = request.json
|
||||
async def enable_llm():
|
||||
req = await request.json
|
||||
TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
|
||||
)
|
||||
@ -287,8 +288,8 @@ def enable_llm():
|
||||
@manager.route("/delete_factory", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def delete_factory():
|
||||
req = request.json
|
||||
async def delete_factory():
|
||||
req = await request.json
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
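
Sketch of why the add_llm hunks above add await when one view delegates to another, assuming both views live in the same module: once set_api_key() is an async function, calling it returns a coroutine, so the delegating view must await it to obtain the actual response object.

async def add_llm_example(req: dict):
    if req.get("llm_factory") == "Tencent Hunyuan":
        return await set_api_key()      # set_api_key is async in the new code
    return get_json_result(data=True)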
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response, request
from flask_login import current_user, login_required
from quart import Response, request
from api.apps import current_user, login_required

from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService

@@ -25,12 +25,12 @@ from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_mcp() -> Response:
async def list_mcp() -> Response:
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))

@@ -40,7 +40,7 @@ def list_mcp() -> Response:
else:
desc = True

req = request.get_json()
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []

@@ -72,8 +72,8 @@ def detail() -> Response:
@manager.route("/create", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "url", "server_type")
def create() -> Response:
req = request.get_json()
async def create() -> Response:
req = await request.get_json()

server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES:

@@ -127,8 +127,8 @@ def create() -> Response:
@manager.route("/update", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id")
def update() -> Response:
req = request.get_json()
async def update() -> Response:
req = await request.get_json()

mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)

@@ -183,8 +183,8 @@ def update() -> Response:
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def rm() -> Response:
req = request.get_json()
async def rm() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])

try:

@@ -201,8 +201,8 @@ def rm() -> Response:
@manager.route("/import", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcpServers")
def import_multiple() -> Response:
req = request.get_json()
async def import_multiple() -> Response:
req = await request.get_json()
servers = req.get("mcpServers", {})
if not servers:
return get_data_error_result(message="No MCP servers provided.")

@@ -268,8 +268,8 @@ def import_multiple() -> Response:
@manager.route("/export", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def export_multiple() -> Response:
req = request.get_json()
async def export_multiple() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])

if not mcp_ids:

@@ -300,8 +300,8 @@ def export_multiple() -> Response:
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def list_tools() -> Response:
req = request.get_json()
async def list_tools() -> Response:
req = await request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")

@@ -347,8 +347,8 @@ def list_tools() -> Response:
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
def test_tool() -> Response:
req = request.get_json()
async def test_tool() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")

@@ -380,8 +380,8 @@ def test_tool() -> Response:
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tools")
def cache_tool() -> Response:
req = request.get_json()
async def cache_tool() -> Response:
req = await request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")

@@ -403,8 +403,8 @@ def cache_tool() -> Response:

@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type")
def test_mcp() -> Response:
req = request.get_json()
async def test_mcp() -> Response:
req = await request.get_json()

url = req.get("url", "")
if not url:
@@ -15,8 +15,8 @@
#

from flask import Response
from flask_login import login_required
from quart import Response
from api.apps import login_required
from api.utils.api_utils import get_json_result
from plugin import GlobalPluginManager

@@ -27,7 +27,7 @@ from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
from api.utils.api_utils import get_result
from flask import request, Response
from quart import request, Response

@manager.route('/agents', methods=['GET']) # noqa: F821
@@ -41,19 +41,19 @@ def list_agents(tenant_id):
    return get_error_data_result("The agent doesn't exist.")
    page_number = int(request.args.get("page", 1))
    items_per_page = int(request.args.get("page_size", 30))
    orderby = request.args.get("orderby", "update_time")
    order_by = request.args.get("orderby", "update_time")
    if request.args.get("desc") == "False" or request.args.get("desc") == "false":
        desc = False
    else:
        desc = True
    canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
    canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
    return get_result(data=canvas)

@manager.route("/agents", methods=["POST"]) # noqa: F821
@token_required
def create_agent(tenant_id: str):
    req: dict[str, Any] = cast(dict[str, Any], request.json)
async def create_agent(tenant_id: str):
    req: dict[str, Any] = cast(dict[str, Any], await request.json)
    req["user_id"] = tenant_id

    if req.get("dsl") is not None:
@@ -89,8 +89,8 @@ def create_agent(tenant_id: str):

@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
@token_required
def update_agent(tenant_id: str, agent_id: str):
    req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None}
async def update_agent(tenant_id: str, agent_id: str):
    req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
    req["user_id"] = tenant_id

    if req.get("dsl") is not None:
@@ -135,8 +135,8 @@ def delete_agent(tenant_id: str, agent_id: str):

@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required
def webhook(tenant_id: str, agent_id: str):
    req = request.json
async def webhook(tenant_id: str, agent_id: str):
    req = await request.json
    if not UserCanvasService.accessible(req["id"], tenant_id):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
@@ -159,10 +159,10 @@ def webhook(tenant_id: str, agent_id: str):
            data=False, message=str(e),
            code=RetCode.EXCEPTION_ERROR)

    def sse():
    async def sse():
        nonlocal canvas
        try:
            for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
            async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
                yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"

    cvs.dsl = json.loads(str(canvas))
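The webhook hunk turns the SSE generator into an async generator so it can drive `canvas.run(...)` with `async for`. A sketch of how such a streaming response looks in Quart; the event payloads here are illustrative, not taken from the diff:

```python
import json
from quart import Quart, Response

app = Quart(__name__)

@app.route("/stream")
async def stream():
    async def sse():
        # In the real handler the events come from `async for ans in canvas.run(...)`.
        for ans in ({"answer": "partial"}, {"answer": "done"}):
            yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"

    # Quart accepts an async generator as the response body.
    return Response(sse(), mimetype="text/event-stream")
```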
@@ -14,22 +14,20 @@
# limitations under the License.
#
import logging

from flask import request

from quart import request
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json

@manager.route("/chats", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id):
    req = request.json
async def create(tenant_id):
    req = await request_json()
    ids = [i for i in req.get("dataset_ids", []) if i]
    for kb_id in ids:
        kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@@ -145,10 +143,10 @@ def create(tenant_id):

@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id):
async def update(tenant_id, chat_id):
    if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
        return get_error_data_result(message="You do not own the chat")
    req = request.json
    req = await request_json()
    ids = req.get("dataset_ids", [])
    if "show_quotation" in req:
        req["do_refer"] = req.pop("show_quotation")
@@ -228,10 +226,10 @@ def update(tenant_id, chat_id):

@manager.route("/chats", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id):
async def delete_chats(tenant_id):
    errors = []
    success_count = 0
    req = request.json
    req = await request_json()
    if not req:
        ids = None
    else:
@@ -251,8 +249,8 @@ def delete(tenant_id):
            errors.append(f"Assistant({id}) not found.")
            continue
        temp_dict = {"status": StatusEnum.INVALID.value}
        DialogService.update_by_id(id, temp_dict)
        success_count += 1
        success_count += DialogService.update_by_id(id, temp_dict)
        print(success_count, "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$", flush=True)

    if errors:
        if success_count > 0:
@@ -18,13 +18,14 @@
import logging
import os
import json
from flask import request
from quart import request
from peewee import OperationalError
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
from api.db.services.user_service import TenantService
from common.constants import RetCode, FileSource, StatusEnum
from api.utils.api_utils import (
@@ -53,7 +54,7 @@ from common import settings

@manager.route("/datasets", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id):
async def create(tenant_id):
    """
    Create a new dataset.
    ---
@@ -115,17 +116,19 @@ def create(tenant_id):
    # | embedding_model| embd_id |
    # | chunk_method | parser_id |

    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
    req, err = await validate_and_parse_json_request(request, CreateDatasetReq)
    if err is not None:
        return get_error_argument_result(err)

    req = KnowledgebaseService.create_with_name(
    e, req = KnowledgebaseService.create_with_name(
        name = req.pop("name", None),
        tenant_id = tenant_id,
        parser_id = req.pop("parser_id", None),
        **req
    )

    if not e:
        return req

    # Insert embedding model(embd id)
    ok, t = TenantService.get_by_id(tenant_id)
    if not ok:
@@ -144,7 +147,6 @@ def create(tenant_id):
        ok, k = KnowledgebaseService.get_by_id(req["id"])
        if not ok:
            return get_error_data_result(message="Dataset created failed")

        response_data = remap_dictionary_keys(k.to_dict())
        return get_result(data=response_data)
    except Exception as e:
@@ -153,7 +155,7 @@ def create(tenant_id):

@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id):
async def delete(tenant_id):
    """
    Delete datasets.
    ---
@@ -191,7 +193,7 @@ def delete(tenant_id):
            schema:
                type: object
    """
    req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
    req, err = await validate_and_parse_json_request(request, DeleteDatasetReq)
    if err is not None:
        return get_error_argument_result(err)

@@ -251,7 +253,7 @@ def delete(tenant_id):

@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, dataset_id):
async def update(tenant_id, dataset_id):
    """
    Update a dataset.
    ---
@@ -317,7 +319,7 @@ def update(tenant_id, dataset_id):
    # | embedding_model| embd_id |
    # | chunk_method | parser_id |
    extras = {"dataset_id": dataset_id}
    req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
    req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
    if err is not None:
        return get_error_argument_result(err)

@@ -532,3 +534,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
        search.index_name(kb.tenant_id), dataset_id)

    return get_result(data=True)

@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
@token_required
def run_graphrag(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )

    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.graphrag_task_id
    if task_id:
        ok, task = TaskService.get_by_id(task_id)
        if not ok:
            logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")

        if task and task.progress not in [-1, 1]:
            return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")

    documents, _ = DocumentService.get_by_kb_id(
        kb_id=dataset_id,
        page_number=0,
        items_per_page=0,
        orderby="create_time",
        desc=False,
        keywords="",
        run_status=[],
        types=[],
        suffix=[],
    )
    if not documents:
        return get_error_data_result(message=f"No documents in Dataset {dataset_id}")

    sample_document = documents[0]
    document_ids = [document["id"] for document in documents]

    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))

    if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
        logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")

    return get_result(data={"graphrag_task_id": task_id})

@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
@token_required
def trace_graphrag(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )

    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.graphrag_task_id
    if not task_id:
        return get_result(data={})

    ok, task = TaskService.get_by_id(task_id)
    if not ok:
        return get_result(data={})

    return get_result(data=task.to_dict())

@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
@token_required
def run_raptor(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )

    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.raptor_task_id
    if task_id:
        ok, task = TaskService.get_by_id(task_id)
        if not ok:
            logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")

        if task and task.progress not in [-1, 1]:
            return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")

    documents, _ = DocumentService.get_by_kb_id(
        kb_id=dataset_id,
        page_number=0,
        items_per_page=0,
        orderby="create_time",
        desc=False,
        keywords="",
        run_status=[],
        types=[],
        suffix=[],
    )
    if not documents:
        return get_error_data_result(message=f"No documents in Dataset {dataset_id}")

    sample_document = documents[0]
    document_ids = [document["id"] for document in documents]

    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))

    if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
        logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")

    return get_result(data={"raptor_task_id": task_id})

@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
@token_required
def trace_raptor(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')

    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )
    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.raptor_task_id
    if not task_id:
        return get_result(data={})

    ok, task = TaskService.get_by_id(task_id)
    if not ok:
        return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")

    return get_result(data=task.to_dict())
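The new `run_graphrag`/`run_raptor` endpoints are token-protected POSTs without a body, and `trace_*` returns the queued task. A hedged usage sketch, assuming the SDK routes are mounted under `/api/v1` on the default port as elsewhere in RAGFlow (host, dataset id, and key are placeholders):

```python
import requests

BASE = "http://localhost:9380/api/v1"          # assumed default deployment
HEADERS = {"Authorization": "Bearer <api_key>"}  # placeholder API key

# Kick off GraphRAG over every document in the dataset.
resp = requests.post(f"{BASE}/datasets/<dataset_id>/run_graphrag", headers=HEADERS)
print(resp.json())   # expected shape: {"data": {"graphrag_task_id": "..."}}

# Poll the task that was queued.
trace = requests.get(f"{BASE}/datasets/<dataset_id>/trace_graphrag", headers=HEADERS)
print(trace.json())
```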
@@ -15,7 +15,7 @@
#
import logging

from flask import request, jsonify
from quart import request, jsonify

from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -29,7 +29,7 @@ from common import settings
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
@apikey_required
@validate_request("knowledge_id", "query")
def retrieval(tenant_id):
async def retrieval(tenant_id):
    """
    Dify-compatible retrieval API
    ---
@@ -113,14 +113,14 @@ def retrieval(tenant_id):
        404:
            description: Knowledge base or document not found
    """
    req = request.json
    req = await request.json
    question = req["query"]
    kb_id = req["knowledge_id"]
    use_kg = req.get("use_kg", False)
    retrieval_setting = req.get("retrieval_setting", {})
    similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
    top = int(retrieval_setting.get("top_k", 1024))
    metadata_condition = req.get("metadata_condition", {})
    metadata_condition = req.get("metadata_condition", {}) or {}
    metas = DocumentService.get_meta_by_kbs([kb_id])

    doc_ids = []
@@ -131,12 +131,10 @@ def retrieval(tenant_id):
            return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)

        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
        print(metadata_condition)
        # print("after", convert_conditions(metadata_condition))
        doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
        # print("doc_ids", doc_ids)
        if not doc_ids and metadata_condition is not None:
            doc_ids = ['-999']
        if metadata_condition:
            doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
        if not doc_ids and metadata_condition:
            doc_ids = ["-999"]
        ranks = settings.retriever.retrieval(
            question,
            embd_mdl,
@@ -20,7 +20,7 @@ import re
from io import BytesIO

import xxhash
from flask import request, send_file
from quart import request, send_file
from peewee import OperationalError
from pydantic import BaseModel, Field, validator

@@ -35,7 +35,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
    request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@@ -69,7 +70,7 @@ class Chunk(BaseModel):

@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
@token_required
def upload(dataset_id, tenant_id):
async def upload(dataset_id, tenant_id):
    """
    Upload documents to a dataset.
    ---
@@ -93,6 +94,10 @@ def upload(dataset_id, tenant_id):
            type: file
            required: true
            description: Document files to upload.
        - in: formData
          name: parent_path
          type: string
          description: Optional nested path under the parent folder. Uses '/' separators.
    responses:
        200:
            description: Successfully uploaded documents.
@@ -126,9 +131,11 @@ def upload(dataset_id, tenant_id):
                            type: string
                            description: Processing status.
    """
    if "file" not in request.files:
    form = await request.form
    files = await request.files
    if "file" not in files:
        return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR)
    file_objs = request.files.getlist("file")
    file_objs = files.getlist("file")
    for file_obj in file_objs:
        if file_obj.filename == "":
            return get_result(message="No file selected!", code=RetCode.ARGUMENT_ERROR)
@@ -151,7 +158,7 @@ def upload(dataset_id, tenant_id):
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not e:
        raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
    err, files = FileService.upload_document(kb, file_objs, tenant_id)
    err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=form.get("parent_path"))
    if err:
        return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
    # rename key's name
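In Quart the multipart form fields and the uploaded files are awaitable properties, which is why the upload hunk reads both once up front. A minimal sketch of that handler shape (field names are illustrative):

```python
from quart import Quart, request, jsonify

app = Quart(__name__)

@app.route("/upload", methods=["POST"])
async def upload():
    form = await request.form      # multipart form fields
    files = await request.files    # uploaded file parts
    if "file" not in files:
        return jsonify({"error": "No file part!"}), 400
    names = [f.filename for f in files.getlist("file")]
    return jsonify({"uploaded": names, "parent_path": form.get("parent_path")})
```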
@@ -175,7 +182,7 @@ def upload(dataset_id, tenant_id):

@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["PUT"]) # noqa: F821
@token_required
def update_doc(tenant_id, dataset_id, document_id):
async def update_doc(tenant_id, dataset_id, document_id):
    """
    Update a document within a dataset.
    ---
@@ -224,7 +231,7 @@ def update_doc(tenant_id, dataset_id, document_id):
            schema:
                type: object
    """
    req = request.json
    req = await request_json()
    if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
        return get_error_data_result(message="You don't own the dataset.")
    e, kb = KnowledgebaseService.get_by_id(dataset_id)
@@ -355,7 +362,7 @@ def update_doc(tenant_id, dataset_id, document_id):

@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821
@token_required
def download(tenant_id, dataset_id, document_id):
async def download(tenant_id, dataset_id, document_id):
    """
    Download a document from a dataset.
    ---
@@ -405,10 +412,10 @@ def download(tenant_id, dataset_id, document_id):
        return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
    file = BytesIO(file_stream)
    # Use send_file with a proper filename and MIME type
    return send_file(
    return await send_file(
        file,
        as_attachment=True,
        download_name=doc[0].name,
        attachment_filename=doc[0].name,
        mimetype="application/octet-stream", # Set a default MIME type
    )
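Note that `send_file` is a coroutine in Quart and, unlike recent Flask, names the download parameter `attachment_filename`, which is what the hunk above switches to. A brief sketch under that assumption:

```python
from io import BytesIO
from quart import Quart, send_file

app = Quart(__name__)

@app.route("/download")
async def download():
    payload = BytesIO(b"example bytes")  # stands in for the stored blob
    return await send_file(
        payload,
        as_attachment=True,
        attachment_filename="example.bin",
        mimetype="application/octet-stream",
    )
```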
@@ -585,7 +592,7 @@ def list_docs(dataset_id, tenant_id):

@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, dataset_id):
async def delete(tenant_id, dataset_id):
    """
    Delete documents from a dataset.
    ---
@@ -624,7 +631,7 @@ def delete(tenant_id, dataset_id):
    """
    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
        return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
    req = request.json
    req = await request_json()
    if not req:
        doc_ids = None
    else:
@@ -695,7 +702,7 @@ def delete(tenant_id, dataset_id):

@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
@token_required
def parse(tenant_id, dataset_id):
async def parse(tenant_id, dataset_id):
    """
    Start parsing documents into chunks.
    ---
@@ -734,7 +741,7 @@ def parse(tenant_id, dataset_id):
    """
    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
        return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
    req = request.json
    req = await request_json()
    if not req.get("document_ids"):
        return get_error_data_result("`document_ids` is required")
    doc_list = req.get("document_ids")
@@ -778,7 +785,7 @@ def parse(tenant_id, dataset_id):

@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
@token_required
def stop_parsing(tenant_id, dataset_id):
async def stop_parsing(tenant_id, dataset_id):
    """
    Stop parsing documents into chunks.
    ---
@@ -817,7 +824,7 @@ def stop_parsing(tenant_id, dataset_id):
    """
    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
        return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
    req = request.json
    req = await request_json()

    if not req.get("document_ids"):
        return get_error_data_result("`document_ids` is required")
@@ -1019,7 +1026,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
    "/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["POST"]
)
@token_required
def add_chunk(tenant_id, dataset_id, document_id):
async def add_chunk(tenant_id, dataset_id, document_id):
    """
    Add a chunk to a document.
    ---
@@ -1089,7 +1096,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
    if not doc:
        return get_error_data_result(message=f"You don't own the document {document_id}.")
    doc = doc[0]
    req = request.json
    req = await request_json()
    if not str(req.get("content", "")).strip():
        return get_error_data_result(message="`content` is required")
    if "important_keywords" in req:
@@ -1148,7 +1155,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
    "datasets/<dataset_id>/documents/<document_id>/chunks", methods=["DELETE"]
)
@token_required
def rm_chunk(tenant_id, dataset_id, document_id):
async def rm_chunk(tenant_id, dataset_id, document_id):
    """
    Remove chunks from a document.
    ---
@@ -1195,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
    docs = DocumentService.get_by_ids([document_id])
    if not docs:
        raise LookupError(f"Can't find the document with ID {document_id}!")
    req = request.json
    req = await request_json()
    condition = {"doc_id": document_id}
    if "chunk_ids" in req:
        unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@@ -1219,7 +1226,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
    "/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id>", methods=["PUT"]
)
@token_required
def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
    """
    Update a chunk within a document.
    ---
@@ -1281,8 +1288,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
    if not doc:
        return get_error_data_result(message=f"You don't own the document {document_id}.")
    doc = doc[0]
    req = request.json
    if "content" in req:
    req = await request_json()
    if "content" in req and req["content"] is not None:
        content = req["content"]
    else:
        content = chunk.get("content_with_weight", "")
@@ -1323,7 +1330,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):

@manager.route("/retrieval", methods=["POST"]) # noqa: F821
@token_required
def retrieval_test(tenant_id):
async def retrieval_test(tenant_id):
    """
    Retrieve chunks based on a query.
    ---
@@ -1404,7 +1411,7 @@ def retrieval_test(tenant_id):
                            format: float
                            description: Similarity score.
    """
    req = request.json
    req = await request_json()
    if not req.get("dataset_ids"):
        return get_error_data_result("`dataset_ids` is required.")
    kb_ids = req["dataset_ids"]
@@ -1427,6 +1434,7 @@ def retrieval_test(tenant_id):
    question = req["question"]
    doc_ids = req.get("document_ids", [])
    use_kg = req.get("use_kg", False)
    toc_enhance = req.get("toc_enhance", False)
    langs = req.get("cross_languages", [])
    if not isinstance(doc_ids, list):
        return get_error_data_result("`documents` should be a list")
@@ -1435,9 +1443,11 @@ def retrieval_test(tenant_id):
        if doc_id not in doc_ids_list:
            return get_error_data_result(f"The datasets don't own the document {doc_id}")
    if not doc_ids:
        metadata_condition = req.get("metadata_condition", {})
        metadata_condition = req.get("metadata_condition", {}) or {}
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        doc_ids = meta_filter(metas, convert_conditions(metadata_condition))
        doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
        if metadata_condition and not doc_ids:
            doc_ids = ["-999"]
    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
@@ -1478,6 +1488,11 @@ def retrieval_test(tenant_id):
            highlight=highlight,
            rank_feature=label_question(question, kbs),
        )
        if toc_enhance:
            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
            cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size)
            if cks:
                ranks["chunks"] = cks
        if use_kg:
            ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
            if ck["content_with_weight"]:
@@ -17,9 +17,7 @@

import pathlib
import re

import flask
from flask import request
from quart import request, make_response
from pathlib import Path

from api.db.services.document_service import DocumentService
@@ -33,11 +31,11 @@ from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from common import settings

from common.constants import RetCode

@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
def upload(tenant_id):
async def upload(tenant_id):
    """
    Upload a file to the system.
    ---
@@ -79,26 +77,28 @@ def upload(tenant_id):
                    type: string
                    description: File type (e.g., document, folder)
    """
    pf_id = request.form.get("parent_id")
    form = await request.form
    files = await request.files
    pf_id = form.get("parent_id")

    if not pf_id:
        root_folder = FileService.get_root_folder(tenant_id)
        pf_id = root_folder["id"]

    if 'file' not in request.files:
        return get_json_result(data=False, message='No file part!', code=400)
    file_objs = request.files.getlist('file')
    if 'file' not in files:
        return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST)
    file_objs = files.getlist('file')

    for file_obj in file_objs:
        if file_obj.filename == '':
            return get_json_result(data=False, message='No selected file!', code=400)
            return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST)

    file_res = []

    try:
        e, pf_folder = FileService.get_by_id(pf_id)
        if not e:
            return get_json_result(data=False, message="Can't find this folder!", code=404)
            return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND)

        for file_obj in file_objs:
            # Handle file path
@@ -114,13 +114,13 @@ def upload(tenant_id):
            if file_len != len_id_list:
                e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
                if not e:
                    return get_json_result(data=False, message="Folder not found!", code=404)
                    return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
                                                        len_id_list)
            else:
                e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
                if not e:
                    return get_json_result(data=False, message="Folder not found!", code=404)
                    return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
                last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
                                                        len_id_list)

@@ -151,7 +151,7 @@ def upload(tenant_id):

@manager.route('/file/create', methods=['POST']) # noqa: F821
@token_required
def create(tenant_id):
async def create(tenant_id):
    """
    Create a new file or folder.
    ---
@@ -193,16 +193,16 @@ def create(tenant_id):
                type:
                    type: string
    """
    req = request.json
    pf_id = request.json.get("parent_id")
    input_file_type = request.json.get("type")
    req = await request.json
    pf_id = await request.json.get("parent_id")
    input_file_type = await request.json.get("type")
    if not pf_id:
        root_folder = FileService.get_root_folder(tenant_id)
        pf_id = root_folder["id"]

    try:
        if not FileService.is_parent_folder_exist(pf_id):
            return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400)
            return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
        if FileService.query(name=req["name"], parent_id=pf_id):
            return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)

@@ -306,13 +306,13 @@ def list_files(tenant_id):
    try:
        e, file = FileService.get_by_id(pf_id)
        if not e:
            return get_json_result(message="Folder not found!", code=404)
            return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)

        files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords)

        parent_folder = FileService.get_parent_folder(pf_id)
        if not parent_folder:
            return get_json_result(message="File not found!", code=404)
            return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)

        return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
    except Exception as e:
@@ -392,7 +392,7 @@ def get_parent_folder():
    try:
        e, file = FileService.get_by_id(file_id)
        if not e:
            return get_json_result(message="Folder not found!", code=404)
            return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)

        parent_folder = FileService.get_parent_folder(file_id)
        return get_json_result(data={"parent_folder": parent_folder.to_json()})
@@ -439,7 +439,7 @@ def get_all_parent_folders(tenant_id):
    try:
        e, file = FileService.get_by_id(file_id)
        if not e:
            return get_json_result(message="Folder not found!", code=404)
            return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)

        parent_folders = FileService.get_all_parent_folders(file_id)
        parent_folders_res = [folder.to_json() for folder in parent_folders]
@@ -450,7 +450,7 @@ def get_all_parent_folders(tenant_id):

@manager.route('/file/rm', methods=['POST']) # noqa: F821
@token_required
def rm(tenant_id):
async def rm(tenant_id):
    """
    Delete one or multiple files/folders.
    ---
@@ -481,40 +481,40 @@ def rm(tenant_id):
                    type: boolean
                    example: true
    """
    req = request.json
    req = await request.json
    file_ids = req["file_ids"]
    try:
        for file_id in file_ids:
            e, file = FileService.get_by_id(file_id)
            if not e:
                return get_json_result(message="File or Folder not found!", code=404)
                return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
            if not file.tenant_id:
                return get_json_result(message="Tenant not found!", code=404)
                return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)

            if file.type == FileType.FOLDER.value:
                file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
                for inner_file_id in file_id_list:
                    e, file = FileService.get_by_id(inner_file_id)
                    if not e:
                        return get_json_result(message="File not found!", code=404)
                        return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
                    settings.STORAGE_IMPL.rm(file.parent_id, file.location)
                FileService.delete_folder_by_pf_id(tenant_id, file_id)
            else:
                settings.STORAGE_IMPL.rm(file.parent_id, file.location)
                if not FileService.delete(file):
                    return get_json_result(message="Database error (File removal)!", code=500)
                    return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR)

            informs = File2DocumentService.get_by_file_id(file_id)
            for inform in informs:
                doc_id = inform.document_id
                e, doc = DocumentService.get_by_id(doc_id)
                if not e:
                    return get_json_result(message="Document not found!", code=404)
                    return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
                tenant_id = DocumentService.get_tenant_id(doc_id)
                if not tenant_id:
                    return get_json_result(message="Tenant not found!", code=404)
                    return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
                if not DocumentService.remove_document(doc, tenant_id):
                    return get_json_result(message="Database error (Document removal)!", code=500)
                    return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR)
                File2DocumentService.delete_by_file_id(file_id)

        return get_json_result(data=True)
@@ -524,7 +524,7 @@ def rm(tenant_id):

@manager.route('/file/rename', methods=['POST']) # noqa: F821
@token_required
def rename(tenant_id):
async def rename(tenant_id):
    """
    Rename a file.
    ---
@@ -556,27 +556,27 @@ def rename(tenant_id):
                    type: boolean
                    example: true
    """
    req = request.json
    req = await request.json
    try:
        e, file = FileService.get_by_id(req["file_id"])
        if not e:
            return get_json_result(message="File not found!", code=404)
            return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)

        if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
                file.name.lower()).suffix:
            return get_json_result(data=False, message="The extension of file can't be changed", code=400)
            return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)

        for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
            if existing_file.name == req["name"]:
                return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)

        if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
            return get_json_result(message="Database error (File rename)!", code=500)
            return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)

        informs = File2DocumentService.get_by_file_id(req["file_id"])
        if informs:
            if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}):
                return get_json_result(message="Database error (Document rename)!", code=500)
                return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR)

        return get_json_result(data=True)
    except Exception as e:
@@ -585,7 +585,7 @@ def rename(tenant_id):

@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
@token_required
def get(tenant_id, file_id):
async def get(tenant_id, file_id):
    """
    Download a file.
    ---
@@ -606,20 +606,20 @@ def get(tenant_id, file_id):
            description: File stream
            schema:
                type: file
        404:
        RetCode.NOT_FOUND:
            description: File not found
    """
    try:
        e, file = FileService.get_by_id(file_id)
        if not e:
            return get_json_result(message="Document not found!", code=404)
            return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)

        blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
        if not blob:
            b, n = File2DocumentService.get_storage_address(file_id=file_id)
            blob = settings.STORAGE_IMPL.get(b, n)

        response = flask.make_response(blob)
        response = await make_response(blob)
        ext = re.search(r"\.([^.]+)$", file.name)
        if ext:
            if file.type == FileType.VISUAL.value:
@@ -633,7 +633,7 @@ def get(tenant_id, file_id):

@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
def move(tenant_id):
async def move(tenant_id):
    """
    Move one or multiple files to another folder.
    ---
@@ -667,7 +667,7 @@ def move(tenant_id):
                    type: boolean
                    example: true
    """
    req = request.json
    req = await request.json
    try:
        file_ids = req["src_file_ids"]
        parent_id = req["dest_file_id"]
@@ -677,13 +677,13 @@ def move(tenant_id):
        for file_id in file_ids:
            file = files_dict[file_id]
            if not file:
                return get_json_result(message="File or Folder not found!", code=404)
                return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
            if not file.tenant_id:
                return get_json_result(message="Tenant not found!", code=404)
                return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)

        fe, _ = FileService.get_by_id(parent_id)
        if not fe:
            return get_json_result(message="Parent Folder not found!", code=404)
            return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND)

        FileService.move_file(file_ids, parent_id)
        return get_json_result(data=True)
@@ -693,8 +693,8 @@ def move(tenant_id):

@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
def convert(tenant_id):
    req = request.json
async def convert(tenant_id):
    req = await request.json
    kb_ids = req["kb_ids"]
    file_ids = req["file_ids"]
    file2documents = []
@@ -705,7 +705,7 @@ def convert(tenant_id):
        for file_id in file_ids:
            file = files_set[file_id]
            if not file:
                return get_json_result(message="File not found!", code=404)
                return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
            file_ids_list = [file_id]
            if file.type == FileType.FOLDER.value:
                file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
@@ -716,13 +716,13 @@ def convert(tenant_id):
                    doc_id = inform.document_id
                    e, doc = DocumentService.get_by_id(doc_id)
                    if not e:
                        return get_json_result(message="Document not found!", code=404)
                        return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
                    tenant_id = DocumentService.get_tenant_id(doc_id)
                    if not tenant_id:
                        return get_json_result(message="Tenant not found!", code=404)
                        return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
                    if not DocumentService.remove_document(doc, tenant_id):
                        return get_json_result(
                            message="Database error (Document removal)!", code=404)
                            message="Database error (Document removal)!", code=RetCode.NOT_FOUND)
                    File2DocumentService.delete_by_file_id(id)

            # insert
@@ -730,11 +730,11 @@ def convert(tenant_id):
            e, kb = KnowledgebaseService.get_by_id(kb_id)
            if not e:
                return get_json_result(
                    message="Can't find this knowledgebase!", code=404)
                    message="Can't find this knowledgebase!", code=RetCode.NOT_FOUND)
            e, file = FileService.get_by_id(id)
            if not e:
                return get_json_result(
                    message="Can't find this file!", code=404)
                    message="Can't find this file!", code=RetCode.NOT_FOUND)

            doc = DocumentService.insert({
                "id": get_uuid(),
@@ -18,7 +18,7 @@ import re
import time

import tiktoken
from flask import Response, jsonify, request
from quart import Response, jsonify, request

from agent.canvas import Canvas
from api.db.db_models import APIToken
@@ -44,8 +44,8 @@ from common import settings

@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id, chat_id):
    req = request.json
async def create(tenant_id, chat_id):
    req = await request.json
    req["dialog_id"] = chat_id
    dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
    if not dia:
@@ -97,8 +97,8 @@ def create_agent_session(tenant_id, agent_id):

@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id, session_id):
    req = request.json
async def update(tenant_id, chat_id, session_id):
    req = await request.json
    req["dialog_id"] = chat_id
    conv_id = session_id
    conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
@@ -119,8 +119,8 @@ def update(tenant_id, chat_id, session_id):

@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def chat_completion(tenant_id, chat_id):
    req = request.json
async def chat_completion(tenant_id, chat_id):
    req = await request.json
    if not req:
        req = {"question": ""}
    if not req.get("session_id"):
@@ -149,7 +149,7 @@ def chat_completion(tenant_id, chat_id):
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def chat_completion_openai_like(tenant_id, chat_id):
async def chat_completion_openai_like(tenant_id, chat_id):
    """
    OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.

@@ -206,7 +206,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
        if reference:
            print(completion.choices[0].message.reference)
    """
    req = request.get_json()
    req = await request.get_json()

    need_reference = bool(req.get("reference", False))

@@ -383,8 +383,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def agents_completion_openai_compatibility(tenant_id, agent_id):
    req = request.json
async def agents_completion_openai_compatibility(tenant_id, agent_id):
    req = await request.json
    tiktokenenc = tiktoken.get_encoding("cl100k_base")
    messages = req.get("messages", [])
    if not messages:
@@ -428,28 +428,26 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
        return resp
    else:
        # For non-streaming, just return the response directly
        response = next(
            completion_openai(
        async for response in completion_openai(
                tenant_id,
                agent_id,
                question,
                session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
                stream=False,
                **req,
            )
        )
        return jsonify(response)
        ):
            return jsonify(response)

@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def agent_completions(tenant_id, agent_id):
    req = request.json
async def agent_completions(tenant_id, agent_id):
    req = await request.json

    if req.get("stream", True):

        def generate():
            for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
        async def generate():
            async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
                if isinstance(answer, str):
                    try:
                        ans = json.loads(answer[5:]) # remove "data:"
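The non-streaming branch above can no longer call `next()` once `completion_openai` is an async generator; iterating with `async for` and returning on the first item is the equivalent. A small illustration with a made-up generator:

```python
import asyncio

async def completions(prompt: str):
    # Stand-in for an async generator such as completion_openai(...).
    yield {"answer": f"echo: {prompt}"}

async def first_item():
    # `next(completions("hi"))` would raise TypeError; async generators
    # are consumed with `async for` (or `__anext__`), and returning inside
    # the loop yields only the first item.
    async for item in completions("hi"):
        return item

print(asyncio.run(first_item()))
```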
@@ -473,7 +471,7 @@ def agent_completions(tenant_id, agent_id):
    full_content = ""
    reference = {}
    final_ans = ""
    for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
    async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
        try:
            ans = json.loads(answer[5:])

@@ -610,13 +608,13 @@ def list_agent_session(tenant_id, agent_id):

@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, chat_id):
async def delete(tenant_id, chat_id):
    if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
        return get_error_data_result(message="You don't own the chat")

    errors = []
    success_count = 0
    req = request.json
    req = await request.json
    convs = ConversationService.query(dialog_id=chat_id)
    if not req:
        ids = None
@@ -661,10 +659,10 @@ def delete(tenant_id, chat_id):

@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete_agent_session(tenant_id, agent_id):
async def delete_agent_session(tenant_id, agent_id):
    errors = []
    success_count = 0
    req = request.json
    req = await request.json
    cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
    if not cvs:
        return get_error_data_result(f"You don't own the agent {agent_id}")
@@ -716,8 +714,8 @@ def delete_agent_session(tenant_id, agent_id):

@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
def ask_about(tenant_id):
    req = request.json
async def ask_about(tenant_id):
    req = await request.json
    if not req.get("question"):
        return get_error_data_result("`question` is required.")
    if not req.get("dataset_ids"):
@@ -755,8 +753,8 @@ def ask_about(tenant_id):

@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
def related_questions(tenant_id):
    req = request.json
async def related_questions(tenant_id):
    req = await request.json
    if not req.get("question"):
        return get_error_data_result("`question` is required.")
    question = req["question"]
@@ -806,8 +804,8 @@ Related search terms:

@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
def chatbot_completions(dialog_id):
    req = request.json
async def chatbot_completions(dialog_id):
    req = await request.json

    token = request.headers.get("Authorization").split()
    if len(token) != 2:
@@ -856,8 +854,8 @@ def chatbots_inputs(dialog_id):

@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
def agent_bot_completions(agent_id):
    req = request.json
async def agent_bot_completions(agent_id):
    req = await request.json

    token = request.headers.get("Authorization").split()
    if len(token) != 2:
@@ -875,7 +873,7 @@ def agent_bot_completions(agent_id):
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp

    for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
    async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
        return get_result(data=answer)

@@ -901,7 +899,7 @@ def begin_inputs(agent_id):

@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
@validate_request("question", "kb_ids")
def ask_about_embedded():
async def ask_about_embedded():
    token = request.headers.get("Authorization").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!"')
@@ -910,7 +908,7 @@ def ask_about_embedded():
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    req = request.json
    req = await request.json
    uid = objs[0].tenant_id

    search_id = req.get("search_id", "")
@@ -940,7 +938,7 @@ def ask_about_embedded():

@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
@validate_request("kb_id", "question")
def retrieval_test_embedded():
async def retrieval_test_embedded():
    token = request.headers.get("Authorization").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!"')
@@ -949,7 +947,7 @@ def retrieval_test_embedded():
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    req = request.json
    req = await request.json
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req["question"]
@@ -977,14 +975,14 @@ def retrieval_test_embedded():
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        if meta_data_filter.get("method") == "auto":
            chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
            filters: dict = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
            if not doc_ids:
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
            if not doc_ids:
                doc_ids = None
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
            if meta_data_filter["manual"] and not doc_ids:
                doc_ids = ["-999"]

    try:
        tenants = UserTenantService.query(user_id=tenant_id)
@@ -1039,7 +1037,7 @@ def retrieval_test_embedded():

@manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821
@validate_request("question")
def related_questions_embedded():
async def related_questions_embedded():
    token = request.headers.get("Authorization").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!"')
@@ -1048,7 +1046,7 @@ def related_questions_embedded():
    if not objs:
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    req = request.json
    req = await request.json
    tenant_id = objs[0].tenant_id
    if not tenant_id:
        return get_error_data_result(message="permission denined.")
@@ -1115,7 +1113,7 @@ def detail_share_embedded():

@manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821
@validate_request("question", "kb_ids")
def mindmap():
async def mindmap():
    token = request.headers.get("Authorization").split()
    if len(token) != 2:
        return get_error_data_result(message='Authorization is not valid!"')
@@ -1125,7 +1123,7 @@ def mindmap():
        return get_error_data_result(message='Authentication error: API key is invalid!"')

    tenant_id = objs[0].tenant_id
    req = request.json
    req = await request.json

    search_id = req.get("search_id", "")
    search_app = SearchService.get_detail(search_id) if search_id else {}
@@ -14,8 +14,8 @@
# limitations under the License.
#

from flask import request
from flask_login import current_user, login_required
from quart import request
from api.apps import current_user, login_required

from api.constants import DATASET_NAME_LIMIT
from api.db.db_models import DB
@@ -30,8 +30,8 @@ from api.utils.api_utils import get_data_error_result, get_json_result, not_allo
@manager.route("/create", methods=["post"]) # noqa: F821
@login_required
@validate_request("name")
def create():
    req = request.get_json()
async def create():
    req = await request.get_json()
    search_name = req["name"]
    description = req.get("description", "")
    if not isinstance(search_name, str):
@@ -65,8 +65,8 @@ def create():
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
    req = request.get_json()
async def update():
    req = await request.get_json()
    if not isinstance(req["name"], str):
        return get_data_error_result(message="Search name must be string.")
    if req["name"].strip() == "":
@@ -140,7 +140,7 @@ def detail():

@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_search_app():
async def list_search_app():
    keywords = request.args.get("keywords", "")
    page_number = int(request.args.get("page", 0))
    items_per_page = int(request.args.get("page_size", 0))
@@ -150,7 +150,7 @@ def list_search_app():
    else:
        desc = True

    req = request.get_json()
    req = await request.get_json()
    owner_ids = req.get("owner_ids", [])
    try:
        if not owner_ids:
@@ -173,8 +173,8 @@ def list_search_app():
@manager.route("/rm", methods=["post"]) # noqa: F821
@login_required
@validate_request("search_id")
def rm():
    req = request.get_json()
async def rm():
    req = await request.get_json()
    search_id = req["search_id"]
    if not SearchService.accessible4deletion(search_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

@@ -17,7 +17,7 @@ import logging
from datetime import datetime
import json

from flask_login import login_required, current_user
from api.apps import login_required, current_user

from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
@@ -34,7 +34,7 @@ from common.time_utils import current_timestamp, datetime_format
from timeit import default_timer as timer

from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from quart import jsonify
from api.utils.health_utils import run_health_checks
from common import settings

@@ -14,10 +14,7 @@
# limitations under the License.
#

from flask import request
from flask_login import login_required, current_user

from api.apps import smtp_mail_server
from quart import request
from api.db import UserTenantRole
from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService
@@ -28,6 +25,7 @@ from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.web_utils import send_invite_email
from common import settings
from api.apps import smtp_mail_server, login_required, current_user

@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
@@ -51,14 +49,14 @@ def user_list(tenant_id):
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
@login_required
@validate_request("email")
def create(tenant_id):
async def create(tenant_id):
    if current_user.id != tenant_id:
        return get_json_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR)

    req = request.json
    req = await request.json
    invite_user_email = req["email"]
    invite_users = UserService.query(email=invite_user_email)
    if not invite_users:
@ -22,8 +22,7 @@ import secrets
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from flask import redirect, request, session, make_response
|
||||
from flask_login import current_user, login_required, login_user, logout_user
|
||||
from quart import redirect, request, session, make_response
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
from api.apps.auth import get_auth_client
|
||||
@ -45,7 +44,7 @@ from api.utils.api_utils import (
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import smtp_mail_server
|
||||
from api.apps import smtp_mail_server, login_required, current_user, login_user, logout_user
|
||||
from api.utils.web_utils import (
|
||||
send_email_html,
|
||||
OTP_LENGTH,
|
||||
@ -61,7 +60,7 @@ from common import settings
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
def login():
|
||||
async def login():
|
||||
"""
|
||||
User login endpoint.
|
||||
---
|
||||
@ -91,10 +90,11 @@ def login():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
if not request.json:
|
||||
json_body = await request.json
|
||||
if not json_body:
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
email = request.json.get("email", "")
|
||||
email = json_body.get("email", "")
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(
|
||||
@ -103,7 +103,7 @@ def login():
|
||||
message=f"Email: {email} is not registered!",
|
||||
)
|
||||
|
||||
password = request.json.get("password")
|
||||
password = json_body.get("password")
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
@ -125,7 +125,8 @@ def login():
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
|
||||
return await construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -501,7 +502,7 @@ def log_out():
|
||||
|
||||
@manager.route("/setting", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def setting_user():
|
||||
async def setting_user():
|
||||
"""
|
||||
Update user settings.
|
||||
---
|
||||
@ -530,7 +531,7 @@ def setting_user():
|
||||
type: object
|
||||
"""
|
||||
update_dict = {}
|
||||
request_data = request.json
|
||||
request_data = await request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
@ -660,7 +661,7 @@ def user_register(user_id, user):
|
||||
|
||||
@manager.route("/register", methods=["POST"]) # noqa: F821
|
||||
@validate_request("nickname", "email", "password")
|
||||
def user_add():
|
||||
async def user_add():
|
||||
"""
|
||||
Register a new user.
|
||||
---
|
||||
@ -697,7 +698,7 @@ def user_add():
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
req = request.json
|
||||
req = await request.json
|
||||
email_address = req["email"]
|
||||
|
||||
# Validate the email address
|
||||
@ -737,7 +738,7 @@ def user_add():
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return construct_response(
|
||||
return await construct_response(
|
||||
data=user.to_json(),
|
||||
auth=user.get_id(),
|
||||
message=f"{nickname}, welcome aboard!",
|
||||
@ -793,7 +794,7 @@ def tenant_info():
|
||||
@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
|
||||
def set_tenant_info():
|
||||
async def set_tenant_info():
|
||||
"""
|
||||
Update tenant information.
|
||||
---
|
||||
@ -830,7 +831,7 @@ def set_tenant_info():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
req = await request.json
|
||||
try:
|
||||
tid = req.pop("tenant_id")
|
||||
TenantService.update_by_id(tid, req)
|
||||
@ -840,7 +841,7 @@ def set_tenant_info():
|
||||
|
||||
|
||||
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
|
||||
def forget_get_captcha():
|
||||
async def forget_get_captcha():
|
||||
"""
|
||||
GET /forget/captcha?email=<email>
|
||||
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS.
|
||||
@ -862,19 +863,19 @@ def forget_get_captcha():
|
||||
from captcha.image import ImageCaptcha
|
||||
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
|
||||
img_bytes = image.generate(captcha_text).read()
|
||||
response = make_response(img_bytes)
|
||||
response = await make_response(img_bytes)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/forget/otp", methods=["POST"]) # noqa: F821
|
||||
def forget_send_otp():
|
||||
async def forget_send_otp():
|
||||
"""
|
||||
POST /forget/otp
|
||||
- Verify the image captcha stored at captcha:{email} (case-insensitive).
|
||||
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
email = req.get("email") or ""
|
||||
captcha = (req.get("captcha") or "").strip()
|
||||
|
||||
@ -935,12 +936,12 @@ def forget_send_otp():
|
||||
|
||||
|
||||
@manager.route("/forget", methods=["POST"]) # noqa: F821
|
||||
def forget():
|
||||
async def forget():
|
||||
"""
|
||||
POST: Verify email + OTP and reset password, then log the user in.
|
||||
Request JSON: { email, otp, new_password, confirm_new_password }
|
||||
"""
|
||||
req = request.get_json()
|
||||
req = await request.get_json()
|
||||
email = req.get("email") or ""
|
||||
otp = (req.get("otp") or "").strip()
|
||||
new_pwd = req.get("new_password")
|
||||
|
||||
@ -25,7 +25,7 @@ from datetime import datetime, timezone
 from enum import Enum
 from functools import wraps

-from flask_login import UserMixin
+from quart_auth import AuthUser
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
 from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
@ -305,6 +305,7 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
                     time.sleep(self.retry_delay * (2 ** attempt))
                 else:
                     raise
+        return None


 class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
@ -594,7 +595,7 @@ def fill_db_model_object(model_object, human_model_dict):
     return model_object


-class User(DataBaseModel, UserMixin):
+class User(DataBaseModel, AuthUser):
     id = CharField(max_length=32, primary_key=True)
     access_token = CharField(max_length=255, null=True, index=True)
     nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
@ -748,7 +749,7 @@ class Knowledgebase(DataBaseModel):
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
@ -772,8 +773,8 @@ class Document(DataBaseModel):
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
|
||||
@ -876,7 +877,7 @@ class Dialog(DataBaseModel):
|
||||
class Conversation(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="conversation name", index=True)
|
||||
message = JSONField(null=True)
|
||||
reference = JSONField(null=True, default=[])
|
||||
user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)
|
||||
|
||||
@ -34,14 +34,17 @@ from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
DEFAULT_SUPERUSER_NICKNAME = os.getenv("DEFAULT_SUPERUSER_NICKNAME", "admin")
|
||||
DEFAULT_SUPERUSER_EMAIL = os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io")
|
||||
DEFAULT_SUPERUSER_PASSWORD = os.getenv("DEFAULT_SUPERUSER_PASSWORD", "admin")
|
||||
|
||||
def init_superuser():
|
||||
def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_EMAIL, password=DEFAULT_SUPERUSER_PASSWORD, role=UserTenantRole.OWNER):
|
||||
user_info = {
|
||||
"id": uuid.uuid1().hex,
|
||||
"password": encode_to_base64("admin"),
|
||||
"nickname": "admin",
|
||||
"password": encode_to_base64(password),
|
||||
"nickname": nickname,
|
||||
"is_superuser": True,
|
||||
"email": "admin@ragflow.io",
|
||||
"email": email,
|
||||
"creator": "system",
|
||||
"status": "1",
|
||||
}
|
||||
@ -58,7 +61,7 @@ def init_superuser():
|
||||
"tenant_id": user_info["id"],
|
||||
"user_id": user_info["id"],
|
||||
"invited_by": user_info["id"],
|
||||
"role": UserTenantRole.OWNER
|
||||
"role": role
|
||||
}
|
||||
|
||||
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||
@ -70,7 +73,7 @@ def init_superuser():
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
logging.info(
|
||||
"Super user initialized. email: admin@ragflow.io, password: admin. Changing the password after login is strongly recommended.")
|
||||
f"Super user initialized. email: {email}, password: {password}. Changing the password after login is strongly recommended.")
|
||||
|
||||
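With the parameterized signature above (and the `--init-superuser` flag added to ragflow_server further down), the superuser no longer has to be the hard-coded admin account. A hedged usage sketch, with placeholder credentials; the import path matches the `from api.db.init_data import ... init_superuser` line in the server changes below:

```python
from api.db.init_data import init_superuser

# Placeholder values; omitted arguments fall back to the DEFAULT_SUPERUSER_* environment variables.
init_superuser(nickname="ops-admin", email="ops@example.com", password="change-me-now")
```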
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = chat_mdl.chat(system="", history=[
|
||||
|
||||
@ -177,7 +177,7 @@ class UserCanvasService(CommonService):
|
||||
return True
|
||||
|
||||
|
||||
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
query = kwargs.get("query", "") or kwargs.get("question", "")
|
||||
files = kwargs.get("files", [])
|
||||
inputs = kwargs.get("inputs", {})
|
||||
@ -219,10 +219,14 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
"id": message_id
|
||||
})
|
||||
txt = ""
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
ans["session_id"] = session_id
|
||||
if ans["event"] == "message":
|
||||
txt += ans["data"]["content"]
|
||||
if ans["data"].get("start_to_think", False):
|
||||
txt += "<think>"
|
||||
elif ans["data"].get("end_to_think", False):
|
||||
txt += "</think>"
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
|
||||
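Because `completion()` is now an async generator, callers consume it with `async for` instead of a plain loop. A minimal sketch of a driver; the module path and argument values are assumptions for illustration, and running it requires the application's database and model setup:

```python
import asyncio

from api.db.services.canvas_service import completion  # assumed module path

async def drive():
    # Each yielded item is already an SSE-style "data: {...}\n\n" string.
    async for chunk in completion("tenant-id", "agent-id", question="hello"):
        print(chunk, end="")

# asyncio.run(drive())
```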
@ -233,7 +237,7 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
API4ConversationService.append_message(conv["id"], conv)
|
||||
|
||||
|
||||
def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
|
||||
prompt_tokens = len(tiktoken_encoder.encode(str(question)))
|
||||
user_id = kwargs.get("user_id", "")
|
||||
@ -241,7 +245,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru
|
||||
if stream:
|
||||
completion_tokens = 0
|
||||
try:
|
||||
for ans in completion(
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
@ -300,7 +304,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru
|
||||
try:
|
||||
all_content = ""
|
||||
reference = {}
|
||||
for ans in completion(
|
||||
async for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import os
|
||||
from typing import Tuple, List
|
||||
|
||||
from anthropic import BaseModel
|
||||
@ -24,7 +25,6 @@ from api.db import InputType
|
||||
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
@ -68,9 +68,10 @@ class ConnectorService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
|
||||
from api.db.services.file_service import FileService
|
||||
e, conn = cls.get_by_id(connector_id)
|
||||
if not e:
|
||||
return
|
||||
return None
|
||||
SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id)
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
@ -103,7 +104,8 @@ class SyncLogsService(CommonService):
|
||||
Knowledgebase.avatar.alias("kb_avatar"),
|
||||
Connector2Kb.auto_parse,
|
||||
cls.model.from_beginning.alias("reindex"),
|
||||
cls.model.status
|
||||
cls.model.status,
|
||||
cls.model.update_time
|
||||
]
|
||||
if not connector_id:
|
||||
fields.append(Connector.config)
|
||||
@ -116,7 +118,11 @@ class SyncLogsService(CommonService):
         if connector_id:
             query = query.where(cls.model.connector_id == connector_id)
         else:
-            interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
+            database_type = os.getenv("DB_TYPE", "mysql")
+            if "postgres" in database_type.lower():
+                interval_expr = SQL("make_interval(mins => t2.refresh_freq)")
+            else:
+                interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
             query = query.where(
                 Connector.input_type == InputType.POLL,
                 Connector.status == TaskStatus.SCHEDULE,
@ -125,11 +131,11 @@
             )

         query = query.distinct().order_by(cls.model.update_time.desc())
-        totbal = query.count()
+        total = query.count()
         if page_number:
             query = query.paginate(page_number, items_per_page)

-        return list(query.dicts()), totbal
+        return list(query.dicts()), total

     @classmethod
     def start(cls, id, connector_id):
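The dialect branch above exists because MySQL's `INTERVAL col MINUTE` syntax has no direct Postgres equivalent, where `make_interval()` builds an interval from a column. The same selection logic in isolation, as a sketch with an illustrative helper name:

```python
import os

from peewee import SQL

def poll_interval_expr() -> SQL:
    # Postgres builds intervals from a column via make_interval(); MySQL uses INTERVAL ... MINUTE.
    if "postgres" in os.getenv("DB_TYPE", "mysql").lower():
        return SQL("make_interval(mins => t2.refresh_freq)")
    return SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
```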
@ -191,6 +197,7 @@ class SyncLogsService(CommonService):
|
||||
|
||||
@classmethod
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
|
||||
from api.db.services.file_service import FileService
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
@ -207,9 +214,21 @@ class SyncLogsService(CommonService):
|
||||
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
|
||||
errs.extend(err)
|
||||
|
||||
# Create a mapping from filename to metadata for later use
|
||||
metadata_map = {}
|
||||
for d in docs:
|
||||
if d.get("metadata"):
|
||||
filename = d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else "")
|
||||
metadata_map[filename] = d["metadata"]
|
||||
|
||||
kb_table_num_map = {}
|
||||
for doc, _ in doc_blob_pairs:
|
||||
doc_ids.append(doc["id"])
|
||||
|
||||
# Set metadata if available for this document
|
||||
if doc["name"] in metadata_map:
|
||||
DocumentService.update_by_id(doc["id"], {"meta_fields": metadata_map[doc["name"]]})
|
||||
|
||||
if not auto_parse or auto_parse == "0":
|
||||
continue
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
@ -242,7 +261,7 @@ class Connector2KbService(CommonService):
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id,
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
|
||||
@ -287,7 +287,7 @@ def convert_conditions(metadata_condition):
     ]


-def meta_filter(metas: dict, filters: list[dict]):
+def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
     doc_ids = set([])

     def filter_out(v2docs, operator, value):
@ -304,6 +304,8 @@ def meta_filter(metas: dict, filters: list[dict]):
             for conds in [
                 (operator == "contains", str(value).lower() in str(input).lower()),
                 (operator == "not contains", str(value).lower() not in str(input).lower()),
+                (operator == "in", str(input).lower() in str(value).lower()),
+                (operator == "not in", str(input).lower() not in str(value).lower()),
                 (operator == "start with", str(input).lower().startswith(str(value).lower())),
                 (operator == "end with", str(input).lower().endswith(str(value).lower())),
                 (operator == "empty", not input),
@ -331,7 +333,10 @@ def meta_filter(metas: dict, filters: list[dict]):
         if not doc_ids:
             doc_ids = set(ids)
         else:
-            doc_ids = doc_ids & set(ids)
+            if logic == "and":
+                doc_ids = doc_ids & set(ids)
+            else:
+                doc_ids = doc_ids | set(ids)
         if not doc_ids:
             return []
     return list(doc_ids)
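The new `logic` parameter turns the per-condition merge from a fixed intersection into an intersection-or-union choice. Condensed to a sketch, with an illustrative function name:

```python
def combine_matches(acc: set[str], ids: set[str], logic: str = "and") -> set[str]:
    # The first condition seeds the accumulator; later ones narrow it ("and") or widen it ("or").
    if not acc:
        return set(ids)
    return acc & ids if logic == "and" else acc | ids

docs_a = {"d1", "d2"}
docs_b = {"d2", "d3"}
assert combine_matches(docs_a, docs_b, "and") == {"d2"}
assert combine_matches(docs_a, docs_b, "or") == {"d1", "d2", "d3"}
```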
@ -342,7 +347,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||
for ans in chat_solo(dialog, messages, stream):
|
||||
yield ans
|
||||
return
|
||||
return None
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
@ -386,7 +391,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
return None
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":
|
||||
@ -407,14 +412,15 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if dialog.meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||
if dialog.meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
elif dialog.meta_data_filter.get("method") == "manual":
|
||||
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
conds = dialog.meta_data_filter["manual"]
|
||||
attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and")))
|
||||
if conds and not attachments:
|
||||
attachments = ["-999"]
|
||||
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
@ -617,6 +623,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
res["audio_binary"] = tts(tts_mdl, answer)
|
||||
yield res
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = """
|
||||
@ -745,7 +753,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text:
|
||||
return
|
||||
return None
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
@ -776,14 +784,14 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
question=question,
|
||||
@ -851,14 +859,14 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
question=question,
|
||||
|
||||
@ -41,6 +41,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common import settings
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
|
||||
@ -113,7 +114,7 @@ class DocumentService(CommonService):
|
||||
def check_doc_health(cls, tenant_id: str, filename):
|
||||
import os
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
|
||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
|
||||
raise RuntimeError("Exceed the maximum file number of a free user!")
|
||||
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
raise RuntimeError("Exceed the maximum length of file name!")
|
||||
@ -309,7 +310,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -322,7 +323,7 @@ class DocumentService(CommonService):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
graph_source = settings.docStoreConn.getFields(
|
||||
graph_source = settings.docStoreConn.get_fields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
@ -464,7 +465,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -473,7 +474,7 @@ class DocumentService(CommonService):
|
||||
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["kb_id"]
|
||||
|
||||
@classmethod
|
||||
@ -486,7 +487,7 @@ class DocumentService(CommonService):
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -533,7 +534,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["embd_id"]
|
||||
|
||||
@classmethod
|
||||
@ -569,7 +570,7 @@ class DocumentService(CommonService):
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return
|
||||
return None
|
||||
return doc_id[0]["id"]
|
||||
|
||||
@classmethod
|
||||
@ -715,7 +716,7 @@ class DocumentService(CommonService):
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
# only for special task and parsed docs and unfinised
|
||||
# only for special task and parsed docs and unfinished
|
||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
@ -922,7 +923,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
doc_nm = {}
|
||||
@ -974,13 +975,13 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
def embedding(doc_id, cnts, batch_size=16):
|
||||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vects = []
|
||||
vectors = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
vectors.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vects
|
||||
return vectors
|
||||
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
try_create_idx = True
|
||||
@ -1011,15 +1012,15 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
except Exception:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vects)
|
||||
vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vectors)
|
||||
for i, d in enumerate(cks):
|
||||
v = vects[i]
|
||||
v = vectors[i]
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType
|
||||
@ -31,7 +30,7 @@ from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus, FileSource, ParserType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
|
||||
from rag.llm.cv_model import GptV4
|
||||
from common import settings
|
||||
|
||||
@ -184,6 +183,7 @@ class FileService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_folder(cls, file, parent_id, name, count):
|
||||
from api.apps import current_user
|
||||
# Recursively create folder structure
|
||||
# Args:
|
||||
# file: Current file object
|
||||
@ -329,7 +329,7 @@ class FileService(CommonService):
|
||||
current_id = start_id
|
||||
while current_id:
|
||||
e, file = cls.get_by_id(current_id)
|
||||
if file.parent_id != file.id and e:
|
||||
if e and file.parent_id != file.id:
|
||||
parent_folders.append(file)
|
||||
current_id = file.parent_id
|
||||
else:
|
||||
@ -423,13 +423,15 @@ class FileService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def upload_document(self, kb, file_objs, user_id, src="local"):
|
||||
def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None):
|
||||
root_folder = self.get_root_folder(user_id)
|
||||
pf_id = root_folder["id"]
|
||||
self.init_knowledgebase_docs(pf_id, user_id)
|
||||
kb_root_folder = self.get_kb_folder(user_id)
|
||||
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
safe_parent_path = sanitize_path(parent_path)
|
||||
|
||||
err, files = [], []
|
||||
for file in file_objs:
|
||||
try:
|
||||
@ -439,7 +441,7 @@ class FileService(CommonService):
|
||||
if filetype == FileType.OTHER.value:
|
||||
raise RuntimeError("This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}"
|
||||
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
|
||||
location += "_"
|
||||
|
||||
@ -506,6 +508,7 @@ class FileService(CommonService):
|
||||
@staticmethod
|
||||
def parse(filename, blob, img_base64=True, tenant_id=None):
|
||||
from rag.app import audio, email, naive, picture, presentation
|
||||
from api.apps import current_user
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
@ -28,6 +28,7 @@ from common.constants import StatusEnum
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.utils.api_utils import get_parser_config, get_data_error_result
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
|
||||
"""Service class for managing knowledge base operations.
|
||||
|
||||
@ -391,12 +392,12 @@ class KnowledgebaseService(CommonService):
|
||||
"""
|
||||
# Validate name
|
||||
if not isinstance(name, str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
return False, get_data_error_result(message="Dataset name must be string.")
|
||||
dataset_name = name.strip()
|
||||
if dataset_name == "":
|
||||
return get_data_error_result(message="Dataset name can't be empty.")
|
||||
return False, get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
|
||||
# Deduplicate name within tenant
|
||||
dataset_name = duplicate_name(
|
||||
@ -409,7 +410,7 @@ class KnowledgebaseService(CommonService):
|
||||
# Verify tenant exists
|
||||
ok, _t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
return False, "Tenant not found."
|
||||
return False, get_data_error_result(message="Tenant not found.")
|
||||
|
||||
# Build payload
|
||||
kb_id = get_uuid()
|
||||
@ -419,12 +420,13 @@ class KnowledgebaseService(CommonService):
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"parser_id": (parser_id or "naive"),
|
||||
**kwargs
|
||||
**kwargs # Includes optional fields such as description, language, permission, avatar, parser_config, etc.
|
||||
}
|
||||
|
||||
# Default parser_config (align with kb_app.create) — do not accept external overrides
|
||||
# Update parser_config (always override with validated default/merged config)
|
||||
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
|
||||
return payload
|
||||
|
||||
return True, payload
|
||||
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -19,6 +19,7 @@ import re
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
from common.constants import LLMType
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||
@ -32,6 +33,14 @@ def get_init_tenant_llm(user_id):
|
||||
from common import settings
|
||||
tenant_llm = []
|
||||
|
||||
model_configs = {
|
||||
LLMType.CHAT: settings.CHAT_CFG,
|
||||
LLMType.EMBEDDING: settings.EMBEDDING_CFG,
|
||||
LLMType.SPEECH2TEXT: settings.ASR_CFG,
|
||||
LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
|
||||
LLMType.RERANK: settings.RERANK_CFG,
|
||||
}
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
@ -54,8 +63,8 @@ def get_init_tenant_llm(user_id):
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": factory_config["api_key"],
|
||||
"api_base": factory_config["base_url"],
|
||||
"api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
|
||||
"api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
@ -80,8 +89,8 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def encode(self, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
|
||||
safe_texts = []
|
||||
for text in texts:
|
||||
token_size = num_tokens_from_string(text)
|
||||
@ -90,7 +99,7 @@ class LLMBundle(LLM4Tenant):
|
||||
safe_texts.append(text[:target_len])
|
||||
else:
|
||||
safe_texts.append(text)
|
||||
|
||||
|
||||
embeddings, used_tokens = self.mdl.encode(safe_texts)
|
||||
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
|
||||
@ -20,7 +20,6 @@
|
||||
|
||||
from common.log_utils import init_root_logger
|
||||
from plugin import GlobalPluginManager
|
||||
init_root_logger("ragflow_server")
|
||||
|
||||
import logging
|
||||
import os
|
||||
@ -30,18 +29,18 @@ import time
|
||||
import traceback
|
||||
import threading
|
||||
import uuid
|
||||
import faulthandler
|
||||
|
||||
from werkzeug.serving import run_simple
|
||||
from api.apps import app, smtp_mail_server
|
||||
from api.db.runtime_config import RuntimeConfig
|
||||
from api.db.services.document_service import DocumentService
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.init_data import init_web_data
|
||||
from api.db.init_data import init_web_data, init_superuser
|
||||
from common.versions import get_ragflow_version
|
||||
from common.config_utils import show_configs
|
||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
stop_event = threading.Event()
|
||||
@ -74,6 +73,8 @@ def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
faulthandler.enable()
|
||||
init_root_logger("ragflow_server")
|
||||
logging.info(r"""
|
||||
____ ___ ______ ______ __
|
||||
/ __ \ / | / ____// ____// /____ _ __
|
||||
@ -110,11 +111,16 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--debug", default=False, help="debug mode", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-superuser", default=False, help="init superuser", action="store_true"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.version:
|
||||
print(get_ragflow_version())
|
||||
sys.exit(0)
|
||||
|
||||
if args.init_superuser:
|
||||
init_superuser()
|
||||
RuntimeConfig.DEBUG = args.debug
|
||||
if RuntimeConfig.DEBUG:
|
||||
logging.info("run on debug mode")
|
||||
@ -153,14 +159,7 @@ if __name__ == '__main__':
|
||||
# start http server
|
||||
try:
|
||||
logging.info("RAGFlow HTTP server start...")
|
||||
run_simple(
|
||||
hostname=settings.HOST_IP,
|
||||
port=settings.HOST_PORT,
|
||||
application=app,
|
||||
threaded=True,
|
||||
use_reloader=RuntimeConfig.DEBUG,
|
||||
use_debugger=RuntimeConfig.DEBUG,
|
||||
)
|
||||
app.run(host=settings.HOST_IP, port=settings.HOST_PORT)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
stop_event.set()
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -24,20 +25,18 @@ from functools import wraps
|
||||
|
||||
import requests
|
||||
import trio
|
||||
from flask import (
|
||||
from quart import (
|
||||
Response,
|
||||
jsonify,
|
||||
request
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
|
||||
from peewee import OperationalError
|
||||
|
||||
from common.constants import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
from common.connection_utils import timeout
|
||||
from common.constants import RetCode
|
||||
@ -46,6 +45,12 @@ from common import settings
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
|
||||
async def request_json():
|
||||
try:
|
||||
return await request.json
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def serialize_for_json(obj):
|
||||
"""
|
||||
Recursively serialize objects to make them JSON serializable.
|
||||
@ -84,7 +89,8 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
|
||||
|
||||
|
||||
def server_error_response(e):
|
||||
logging.exception(e)
|
||||
# Quart invokes this handler outside the original except block, so we must pass exc_info manually.
|
||||
logging.error("Unhandled exception during request", exc_info=(type(e), e, e.__traceback__))
|
||||
try:
|
||||
msg = repr(e).lower()
|
||||
if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
|
||||
@ -105,31 +111,37 @@ def server_error_response(e):
|
||||
|
||||
|
||||
def validate_request(*args, **kwargs):
|
||||
def process_args(input_arguments):
|
||||
no_arguments = []
|
||||
error_arguments = []
|
||||
for arg in args:
|
||||
if arg not in input_arguments:
|
||||
no_arguments.append(arg)
|
||||
for k, v in kwargs.items():
|
||||
config_value = input_arguments.get(k, None)
|
||||
if config_value is None:
|
||||
no_arguments.append(k)
|
||||
elif isinstance(v, (tuple, list)):
|
||||
if config_value not in v:
|
||||
error_arguments.append((k, set(v)))
|
||||
elif config_value != v:
|
||||
error_arguments.append((k, v))
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return error_string
|
||||
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*_args, **_kwargs):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
no_arguments = []
|
||||
error_arguments = []
|
||||
for arg in args:
|
||||
if arg not in input_arguments:
|
||||
no_arguments.append(arg)
|
||||
for k, v in kwargs.items():
|
||||
config_value = input_arguments.get(k, None)
|
||||
if config_value is None:
|
||||
no_arguments.append(k)
|
||||
elif isinstance(v, (tuple, list)):
|
||||
if config_value not in v:
|
||||
error_arguments.append((k, set(v)))
|
||||
elif config_value != v:
|
||||
error_arguments.append((k, v))
|
||||
if no_arguments or error_arguments:
|
||||
error_string = ""
|
||||
if no_arguments:
|
||||
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string)
|
||||
async def decorated_function(*_args, **_kwargs):
|
||||
errs = process_args(await request.json or (await request.form).to_dict())
|
||||
if errs:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*_args, **_kwargs)
|
||||
return func(*_args, **_kwargs)
|
||||
|
||||
return decorated_function
|
||||
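Several decorators in this file (`validate_request`, `not_allowed_parameters`, `active_required`, `apikey_required`) are rewritten so they can wrap both synchronous and async view functions. The shared pattern, reduced to a sketch with illustrative names:

```python
import inspect
from functools import wraps

def passthrough(func):
    """Sketch of the dual-mode wrapper: run shared validation, then call or await the view."""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # Shared validation would run here and may return an error response early.
        if inspect.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        return func(*args, **kwargs)
    return wrapper
```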
@ -138,30 +150,34 @@ def validate_request(*args, **kwargs):
|
||||
|
||||
|
||||
def not_allowed_parameters(*params):
|
||||
def decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||
def decorator(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
input_arguments = await request.json or (await request.form).to_dict()
|
||||
for param in params:
|
||||
if param in input_arguments:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
def active_required(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
from api.db.services import UserService
|
||||
from api.apps import current_user
|
||||
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return get_json_result(code=RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||
return f(*args, **kwargs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -173,12 +189,15 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non
|
||||
|
||||
def apikey_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
token = flask_request.headers.get("Authorization").split()[1]
|
||||
async def decorated_function(*args, **kwargs):
|
||||
token = request.headers.get("Authorization").split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
@ -199,23 +218,38 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da
|
||||
|
||||
|
||||
def token_required(func):
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
def get_tenant_id(**kwargs):
|
||||
if os.environ.get("DISABLE_SDK"):
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_str = flask_request.headers.get("Authorization")
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_str = request.headers.get("Authorization")
|
||||
if not authorization_str:
|
||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
return False, get_json_result(data=False, message="`Authorization` can't be empty")
|
||||
authorization_list = authorization_str.split()
|
||||
if len(authorization_list) < 2:
|
||||
return get_json_result(data=False, message="Please check your authorization format.")
|
||||
return False, get_json_result(data=False, message="Please check your authorization format.")
|
||||
token = authorization_list[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR)
|
||||
kwargs["tenant_id"] = objs[0].tenant_id
|
||||
return True, kwargs
|
||||
|
||||
@wraps(func)
|
||||
def decorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def adecorated_function(*args, **kwargs):
|
||||
e, kwargs = get_tenant_id(**kwargs)
|
||||
if not e:
|
||||
return kwargs
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return adecorated_function
|
||||
return decorated_function
|
||||
|
||||
|
||||
@ -279,6 +313,10 @@ def get_parser_config(chunk_method, parser_config):
|
||||
chunk_method = "naive"
|
||||
|
||||
# Define default configurations for each chunking method
|
||||
base_defaults = {
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
}
|
||||
key_mapping = {
|
||||
"naive": {
|
||||
"layout_recognize": "DeepDOC",
|
||||
@ -331,16 +369,19 @@ def get_parser_config(chunk_method, parser_config):
|
||||
|
||||
default_config = key_mapping[chunk_method]
|
||||
|
||||
# If no parser_config provided, return default
|
||||
# If no parser_config provided, return default merged with base defaults
|
||||
if not parser_config:
|
||||
return default_config
|
||||
if default_config is None:
|
||||
return deep_merge(base_defaults, {})
|
||||
return deep_merge(base_defaults, default_config)
|
||||
|
||||
# If parser_config is provided, merge with defaults to ensure required fields exist
|
||||
if default_config is None:
|
||||
return parser_config
|
||||
return deep_merge(base_defaults, parser_config)
|
||||
|
||||
# Ensure raptor and graphrag fields have default values if not provided
|
||||
merged_config = deep_merge(default_config, parser_config)
|
||||
merged_config = deep_merge(base_defaults, default_config)
|
||||
merged_config = deep_merge(merged_config, parser_config)
|
||||
|
||||
return merged_config
|
||||
|
||||
|
||||
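`get_parser_config` now layers three dictionaries, with later layers winning: the base defaults (`table_context_size`, `image_context_size`), the per-method defaults, and the caller's config. A self-contained illustration; the `deep_merge` shown here is a minimal stand-in assumed to match the helper used by api_utils:

```python
def deep_merge(base: dict, override: dict) -> dict:
    # Minimal recursive merge: nested dicts are merged, scalar values in `override` win.
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

base_defaults = {"table_context_size": 0, "image_context_size": 0}
naive_defaults = {"layout_recognize": "DeepDOC"}
caller_config = {"table_context_size": 2}

result = deep_merge(deep_merge(base_defaults, naive_defaults), caller_config)
# {'table_context_size': 2, 'image_context_size': 0, 'layout_recognize': 'DeepDOC'}
```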
@ -18,7 +18,7 @@ import base64
|
||||
import click
|
||||
import re
|
||||
|
||||
from flask import Flask
|
||||
from quart import Quart
|
||||
from werkzeug.security import generate_password_hash
|
||||
|
||||
from api.db.services import UserService
|
||||
@ -73,6 +73,7 @@ def reset_email(email, new_email, email_confirm):
|
||||
UserService.update_user(user[0].id,user_dict)
|
||||
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
|
||||
|
||||
def register_commands(app: Flask):
|
||||
|
||||
def register_commands(app: Quart):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Reusable HTML email templates and registry.
|
||||
"""
|
||||
|
||||
@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob):
             return repaired

     return blob
+
+
+def sanitize_path(raw_path: str | None) -> str:
+    """Normalize and sanitize a user-provided path segment.
+
+    - Converts backslashes to forward slashes
+    - Strips leading/trailing slashes
+    - Removes '.' and '..' segments
+    - Restricts characters to A-Za-z0-9, underscore, dash, and '/'
+    """
+    if not raw_path:
+        return ""
+    backslash_re = re.compile(r"[\\]+")
+    unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]")
+    normalized = backslash_re.sub("/", raw_path)
+    normalized = normalized.strip("/")
+    parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")]
+    sanitized = "/".join(parts)
+    sanitized = unsafe_re.sub("", sanitized)
+    return sanitized
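A quick illustration of what the new `sanitize_path` accepts and drops; the expected outputs are hand-traced from the code above, so treat them as a sketch rather than test vectors:

```python
from api.utils.file_utils import sanitize_path

# Backslashes become '/', '..' and '.' segments are dropped, and the space is outside the allowed set.
print(sanitize_path("..\\reports/2024 Q1/"))   # -> "reports/2024Q1"
print(sanitize_path(None))                     # -> ""
```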
@ -173,7 +173,8 @@ def check_task_executor_alive():
             heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
             task_executor_heartbeats[task_executor_id] = heartbeats
         if task_executor_heartbeats:
-            return {"status": "alive", "message": task_executor_heartbeats}
+            status = "alive" if any(task_executor_heartbeats.values()) else "timeout"
+            return {"status": status, "message": task_executor_heartbeats}
         else:
             return {"status": "timeout", "message": "Not found any task executor."}
     except Exception as e:

@ -1,3 +1,19 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
@ -17,7 +17,7 @@ from collections import Counter
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Request
|
||||
from quart import Request
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@ -32,7 +32,7 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
|
||||
|
||||
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
|
||||
async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""
|
||||
Validates and parses JSON requests through a multi-stage validation pipeline.
|
||||
|
||||
@ -81,7 +81,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
|
||||
from the final output after validation
|
||||
"""
|
||||
try:
|
||||
payload = request.get_json() or {}
|
||||
payload = await request.get_json() or {}
|
||||
except UnsupportedMediaType:
|
||||
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
|
||||
except BadRequest:
|
||||
|
||||
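# Editor's sketch (hedged, not from the diff): with validate_and_parse_json_request
# now a coroutine, Quart handlers must await it. The app, route, and
# "DatasetCreateSchema" below are placeholders, not names from this repository.
from quart import Quart, request

app = Quart(__name__)

@app.post("/api/datasets")
async def create_dataset():
    payload, err = await validate_and_parse_json_request(request, DatasetCreateSchema)
    if err is not None:
        return {"code": 400, "message": err}, 400
    return {"code": 0, "message": "success", "data": payload}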
@ -23,7 +23,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from api.apps import smtp_mail_server
|
||||
from flask_mail import Message
|
||||
from flask import render_template_string
|
||||
from quart import render_template_string
|
||||
from api.utils.email_templates import EMAIL_TEMPLATES
|
||||
from selenium import webdriver
|
||||
from selenium.common.exceptions import TimeoutException
|
||||
|
||||
48
check_comment_ascii.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Check whether the given Python files contain non-ASCII comments.
|
||||
|
||||
How to check the whole git repo:
|
||||
|
||||
```
|
||||
$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py
|
||||
```
|
||||
"""
|
||||
|
||||
import sys
|
||||
import tokenize
|
||||
import ast
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
ASCII = re.compile(r"^[\n -~]*\Z") # Printable ASCII + newline
|
||||
|
||||
|
||||
def check(src: str, name: str) -> int:
|
||||
"""
|
||||
docstring line 1
|
||||
docstring line 2
|
||||
"""
|
||||
ok = 1
|
||||
# A common comment begins with `#`
|
||||
with tokenize.open(src) as fp:
|
||||
for tk in tokenize.generate_tokens(fp.readline):
|
||||
if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
|
||||
print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
|
||||
ok = 0
|
||||
# A docstring begins and ends with `'''`
|
||||
for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)):
|
||||
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
|
||||
if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
|
||||
print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
|
||||
ok = 0
|
||||
return ok
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
status = 0
|
||||
for file in sys.argv[1:]:
|
||||
if not check(file, file):
|
||||
status = 1
|
||||
sys.exit(status)
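# Editor's note (illustrative, not part of the new file): the script exits with
# status 0 when every comment and docstring is ASCII, and 1 otherwise, e.g.
#
#   $ python3 check_comment_ascii.py path/to/module.py && echo ok
#
# Offending lines are reported as "<file>:<lineno>: non-ASCII comment: <text>"
# or "<file>:<lineno>: non-ASCII docstring: <text>", matching the prints above.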
|
||||
@ -21,7 +21,7 @@ from typing import Any, Callable, Coroutine, Optional, Type, Union
|
||||
import asyncio
|
||||
import trio
|
||||
from functools import wraps
|
||||
from flask import make_response, jsonify
|
||||
from quart import make_response, jsonify
|
||||
from common.constants import RetCode
|
||||
|
||||
TimeoutException = Union[Type[BaseException], BaseException]
|
||||
@ -103,7 +103,7 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
|
||||
return decorator
|
||||
|
||||
|
||||
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
async def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
result_dict = {"code": code, "message": message, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
@ -111,7 +111,27 @@ def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
response = make_response(jsonify(response_dict))
|
||||
response = await make_response(jsonify(response_dict))
|
||||
if auth:
|
||||
response.headers["Authorization"] = auth
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Method"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Expose-Headers"] = "Authorization"
|
||||
return response
|
||||
|
||||
|
||||
def sync_construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||
import flask
|
||||
result_dict = {"code": code, "message": message, "data": data}
|
||||
response_dict = {}
|
||||
for key, value in result_dict.items():
|
||||
if value is None and key != "code":
|
||||
continue
|
||||
else:
|
||||
response_dict[key] = value
|
||||
response = flask.make_response(flask.jsonify(response_dict))
|
||||
if auth:
|
||||
response.headers["Authorization"] = auth
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
|
||||
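# Editor's sketch (hedged, not from the diff): after the change above, coroutine
# handlers await construct_response, while sync_construct_response keeps the
# Flask-style behavior for any remaining synchronous call sites.
async def ping():
    # RetCode.SUCCESS comes from common.constants, as imported in this module.
    return await construct_response(code=RetCode.SUCCESS, message="pong", data={"alive": True})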
@ -49,6 +49,7 @@ class RetCode(IntEnum, CustomEnum):
|
||||
RUNNING = 106
|
||||
PERMISSION_ERROR = 108
|
||||
AUTHENTICATION_ERROR = 109
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
SERVER_ERROR = 500
|
||||
FORBIDDEN = 403
|
||||
@ -118,6 +119,9 @@ class FileSource(StrEnum):
|
||||
SHAREPOINT = "sharepoint"
|
||||
SLACK = "slack"
|
||||
TEAMS = "teams"
|
||||
WEBDAV = "webdav"
|
||||
MOODLE = "moodle"
|
||||
DROPBOX = "dropbox"
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
|
||||
@ -11,9 +11,11 @@ from .confluence_connector import ConfluenceConnector
|
||||
from .discord_connector import DiscordConnector
|
||||
from .dropbox_connector import DropboxConnector
|
||||
from .google_drive.connector import GoogleDriveConnector
|
||||
from .jira_connector import JiraConnector
|
||||
from .jira.connector import JiraConnector
|
||||
from .sharepoint_connector import SharePointConnector
|
||||
from .teams_connector import TeamsConnector
|
||||
from .webdav_connector import WebDAVConnector
|
||||
from .moodle_connector import MoodleConnector
|
||||
from .config import BlobType, DocumentSource
|
||||
from .models import Document, TextSection, ImageSection, BasicExpertInfo
|
||||
from .exceptions import (
|
||||
@ -36,6 +38,8 @@ __all__ = [
|
||||
"JiraConnector",
|
||||
"SharePointConnector",
|
||||
"TeamsConnector",
|
||||
"WebDAVConnector",
|
||||
"MoodleConnector",
|
||||
"BlobType",
|
||||
"DocumentSource",
|
||||
"Document",
|
||||
|
||||
@ -87,6 +87,13 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
|
||||
|
||||
elif self.bucket_type == BlobType.S3_COMPATIBLE:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key", "addressing_style"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("S3 Compatible Storage")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
|
||||
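# Editor's example (hedged, not from the diff): the credential keys required by
# the new S3_COMPATIBLE branch above; every value here is a placeholder.
s3_compatible_credentials = {
    "endpoint_url": "https://storage.example.com",
    "aws_access_key_id": "EXAMPLE_ACCESS_KEY",
    "aws_secret_access_key": "EXAMPLE_SECRET_KEY",
    "addressing_style": "path",
}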
@ -13,6 +13,7 @@ def get_current_tz_offset() -> int:
|
||||
return round(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
ONE_MINUTE = 60
|
||||
ONE_HOUR = 3600
|
||||
ONE_DAY = ONE_HOUR * 24
|
||||
|
||||
@ -31,6 +32,7 @@ class BlobType(str, Enum):
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
S3_COMPATIBLE = "s3_compatible"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
@ -42,9 +44,14 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
SLACK = "slack"
|
||||
CONFLUENCE = "confluence"
|
||||
JIRA = "jira"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
GMAIL = "gmail"
|
||||
DISCORD = "discord"
|
||||
WEBDAV = "webdav"
|
||||
MOODLE = "moodle"
|
||||
S3_COMPATIBLE = "s3_compatible"
|
||||
DROPBOX = "dropbox"
|
||||
|
||||
|
||||
class FileOrigin(str, Enum):
|
||||
@ -178,6 +185,21 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
if ignored_tag
|
||||
]
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
JIRA_SYNC_TIME_BUFFER_SECONDS = int(
|
||||
os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE)
|
||||
)
|
||||
JIRA_TIMEZONE_OFFSET = float(
|
||||
os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset())
|
||||
)
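# Editor's note (illustrative): how the JIRA_* settings above resolve, e.g.
#   JIRA_CONNECTOR_LABELS_TO_SKIP="internal,wontfix"  -> ["internal", "wontfix"]
#   JIRA_CONNECTOR_MAX_TICKET_SIZE unset              -> 102400 bytes (100 KiB)
#   JIRA_SYNC_TIME_BUFFER_SECONDS unset               -> 60 (ONE_MINUTE)
#   JIRA_TIMEZONE_OFFSET unset                        -> local offset from get_current_tz_offset()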
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
|
||||
@ -1562,6 +1562,7 @@ class ConfluenceConnector(
|
||||
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
|
||||
doc_updated_at=datetime_from_string(page["version"]["when"]),
|
||||
primary_owners=primary_owners if primary_owners else None,
|
||||
metadata=metadata if metadata else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
|
||||
@ -1788,6 +1789,7 @@ class ConfluenceConnector(
|
||||
cql_url = self.confluence_client.build_cql_url(
|
||||
page_query, expand=",".join(_PAGE_EXPANSION_FIELDS)
|
||||
)
|
||||
logging.info(f"[Confluence Connector] Building CQL URL {cql_url}")
|
||||
return update_param_in_path(cql_url, "limit", str(limit))
|
||||
|
||||
@override
|
||||
|
||||
@ -65,6 +65,7 @@ def _convert_message_to_document(
|
||||
blob=message.content.encode("utf-8"),
|
||||
extension=".txt",
|
||||
size_bytes=len(message.content.encode("utf-8")),
|
||||
metadata=metadata if metadata else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,13 +1,24 @@
|
||||
"""Dropbox connector"""
|
||||
|
||||
import logging
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from dropbox import Dropbox
|
||||
from dropbox.exceptions import ApiError, AuthError
|
||||
from dropbox.files import FileMetadata, FolderMetadata
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.models import Document, GenerateDocumentsOutput
|
||||
from common.data_source.utils import get_file_ext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DropboxConnector(LoadConnector, PollConnector):
|
||||
@ -19,29 +30,29 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Dropbox credentials"""
|
||||
try:
|
||||
access_token = credentials.get("dropbox_access_token")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("Dropbox access token is required")
|
||||
|
||||
self.dropbox_client = Dropbox(access_token)
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Dropbox: {e}")
|
||||
access_token = credentials.get("dropbox_access_token")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("Dropbox access token is required")
|
||||
|
||||
self.dropbox_client = Dropbox(access_token)
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Dropbox connector settings"""
|
||||
if not self.dropbox_client:
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
|
||||
try:
|
||||
# Test connection by getting current account info
|
||||
self.dropbox_client.users_get_current_account()
|
||||
except (AuthError, ApiError) as e:
|
||||
if "invalid_access_token" in str(e).lower():
|
||||
raise InsufficientPermissionsError("Invalid Dropbox access token")
|
||||
else:
|
||||
raise ConnectorValidationError(f"Dropbox validation error: {e}")
|
||||
self.dropbox_client.files_list_folder(path="", limit=1)
|
||||
except AuthError as e:
|
||||
logger.exception("[Dropbox]: Failed to validate Dropbox credentials")
|
||||
raise ConnectorValidationError(f"Dropbox credential is invalid: {e}")
|
||||
except ApiError as e:
|
||||
if e.error is not None and "insufficient_permissions" in str(e.error).lower():
|
||||
raise InsufficientPermissionsError("Your Dropbox token does not have sufficient permissions.")
|
||||
raise ConnectorValidationError(f"Unexpected Dropbox error during validation: {e.user_message_text or e}")
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Unexpected error during Dropbox settings validation: {e}")
|
||||
|
||||
def _download_file(self, path: str) -> bytes:
|
||||
"""Download a single file from Dropbox."""
|
||||
@ -54,26 +65,105 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
"""Create a shared link for a file in Dropbox."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
|
||||
try:
|
||||
# Try to get existing shared links first
|
||||
shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
|
||||
if shared_links.links:
|
||||
return shared_links.links[0].url
|
||||
|
||||
# Create a new shared link
|
||||
link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path)
|
||||
return link_settings.url
|
||||
except Exception:
|
||||
# Fallback to basic link format
|
||||
return f"https://www.dropbox.com/home{path}"
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
link_metadata = self.dropbox_client.sharing_create_shared_link_with_settings(path)
|
||||
return link_metadata.url
|
||||
except ApiError as err:
|
||||
logger.exception(f"[Dropbox]: Failed to create a shared link for {path}: {err}")
|
||||
return ""
|
||||
|
||||
def _yield_files_recursive(
|
||||
self,
|
||||
path: str,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Yield files in batches from a specified Dropbox folder, including subfolders."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
result = self.dropbox_client.files_list_folder(
|
||||
path,
|
||||
limit=self.batch_size,
|
||||
recursive=False,
|
||||
include_non_downloadable_files=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
if modified_time.tzinfo is None:
|
||||
modified_time = modified_time.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
modified_time = modified_time.astimezone(timezone.utc)
|
||||
|
||||
time_as_seconds = modified_time.timestamp()
|
||||
if start is not None and time_as_seconds <= start:
|
||||
continue
|
||||
if end is not None and time_as_seconds > end:
|
||||
continue
|
||||
|
||||
try:
|
||||
downloaded_file = self._download_file(entry.path_display)
|
||||
except Exception:
|
||||
logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
|
||||
continue
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=entry.name,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(entry, FolderMetadata):
|
||||
yield from self._yield_files_recursive(entry.path_lower, start, end)
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
if not result.has_more:
|
||||
break
|
||||
|
||||
result = self.dropbox_client.files_list_folder_continue(result.cursor)
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
|
||||
"""Poll Dropbox for recent file changes"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
def load_from_state(self) -> Any:
|
||||
for batch in self._yield_files_recursive("", start, end):
|
||||
yield batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Load files from Dropbox state"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
return self._yield_files_recursive("", None, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
connector = DropboxConnector()
|
||||
connector.load_credentials({"dropbox_access_token": os.environ.get("DROPBOX_ACCESS_TOKEN")})
|
||||
connector.validate_connector_settings()
|
||||
document_batches = connector.load_from_state()
|
||||
try:
|
||||
first_batch = next(document_batches)
|
||||
print(f"Loaded {len(first_batch)} documents in first batch.")
|
||||
for doc in first_batch:
|
||||
print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)")
|
||||
except StopIteration:
|
||||
print("No documents available in Dropbox.")
|
||||
|
||||
@ -3,15 +3,9 @@ import os
|
||||
import threading
|
||||
from typing import Any, Callable
|
||||
|
||||
import requests
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_SCOPES
|
||||
|
||||
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
|
||||
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
DEFAULT_DEVICE_INTERVAL = 5
|
||||
|
||||
|
||||
def _get_requested_scopes(source: DocumentSource) -> list[str]:
|
||||
"""Return the scopes to request, honoring an optional override env var."""
|
||||
@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
|
||||
return result.get("value")
|
||||
|
||||
|
||||
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
|
||||
if "client_id" in credentials:
|
||||
return credentials["client_id"], credentials.get("client_secret")
|
||||
for key in ("installed", "web"):
|
||||
if key in credentials and isinstance(credentials[key], dict):
|
||||
nested = credentials[key]
|
||||
if "client_id" not in nested:
|
||||
break
|
||||
return nested["client_id"], nested.get("client_secret")
|
||||
raise ValueError("Provided Google OAuth credentials are missing client_id.")
|
||||
|
||||
|
||||
def start_device_authorization_flow(
|
||||
credentials: dict[str, Any],
|
||||
source: DocumentSource,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
client_id, client_secret = _extract_client_info(credentials)
|
||||
data = {
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(_get_requested_scopes(source)),
|
||||
}
|
||||
if client_secret:
|
||||
data["client_secret"] = client_secret
|
||||
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
state = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"device_code": payload.get("device_code"),
|
||||
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
|
||||
}
|
||||
response_data = {
|
||||
"user_code": payload.get("user_code"),
|
||||
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
|
||||
"verification_url_complete": payload.get("verification_url_complete")
|
||||
or payload.get("verification_uri_complete"),
|
||||
"expires_in": payload.get("expires_in"),
|
||||
"interval": state["interval"],
|
||||
}
|
||||
return state, response_data
|
||||
|
||||
|
||||
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
|
||||
data = {
|
||||
"client_id": state["client_id"],
|
||||
"device_code": state["device_code"],
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}
|
||||
if state.get("client_secret"):
|
||||
data["client_secret"] = state["client_secret"]
|
||||
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
|
||||
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
|
||||
port = int(preferred_port) if preferred_port else 0
|
||||
timeout_secs = _get_oauth_timeout_secs()
|
||||
timeout_message = (
|
||||
f"Google OAuth verification timed out after {timeout_secs} seconds. "
|
||||
"Close any pending consent windows and rerun the connector configuration to try again."
|
||||
)
|
||||
timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
|
||||
|
||||
print("Launching Google OAuth flow. A browser window should open shortly.")
|
||||
print("If it does not, copy the URL shown in the console into your browser manually.")
|
||||
@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
|
||||
instructions = [
|
||||
"Google rejected one or more of the requested OAuth scopes.",
|
||||
"Fix options:",
|
||||
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
|
||||
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
|
||||
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
|
||||
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
|
||||
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
|
||||
" (be aware the connector may lose functionality).",
|
||||
]
|
||||
raise RuntimeError("\n".join(instructions)) from warning
|
||||
raise
|
||||
@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
|
||||
client_config = {"web": credentials["web"]}
|
||||
|
||||
if client_config is None:
|
||||
raise ValueError(
|
||||
"Provided Google OAuth credentials are missing both tokens and a client configuration."
|
||||
)
|
||||
raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
|
||||
|
||||
return _run_local_server_flow(client_config, source)
|
||||
|
||||
@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC):
|
||||
|
||||
|
||||
class CheckpointedConnectorWithPermSync(ABC):
|
||||
"""Checkpointed connector interface (with permission sync)"""
|
||||
"""Checkpoint connector interface (with permission sync)"""
|
||||
|
||||
@abstractmethod
|
||||
def load_from_checkpoint(
|
||||
@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dynamic(self) -> bool:
|
||||
"""If dynamic, the credentials may change during usage ... maening the client
|
||||
"""If dynamic, the credentials may change during usage ... meaning the client
|
||||
needs to use the locking features of the credentials provider to operate
|
||||
correctly.
|
||||
|
||||
|
||||
0
common/data_source/jira/__init__.py
Normal file
973
common/data_source/jira/connector.py
Normal file
@ -0,0 +1,973 @@
|
||||
"""Checkpointed Jira connector that emits markdown blobs for each issue."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
from pydantic import Field
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE,
|
||||
JIRA_TIMEZONE_OFFSET,
|
||||
ONE_HOUR,
|
||||
DocumentSource,
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError,
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
CheckpointOutputWrapper,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync,
|
||||
)
|
||||
from common.data_source.jira.utils import (
|
||||
JIRA_CLOUD_API_VERSION,
|
||||
JIRA_SERVER_API_VERSION,
|
||||
build_issue_url,
|
||||
extract_body_text,
|
||||
extract_named_value,
|
||||
extract_user,
|
||||
format_attachments,
|
||||
format_comments,
|
||||
parse_jira_datetime,
|
||||
should_skip_issue,
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint,
|
||||
ConnectorFailure,
|
||||
Document,
|
||||
DocumentFailure,
|
||||
SlimDocument,
|
||||
)
|
||||
from common.data_source.utils import is_atlassian_cloud_url, is_atlassian_date_error, scoped_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_FIELDS = "summary,description,updated,created,status,priority,assignee,reporter,labels,issuetype,project,comment,attachment"
|
||||
_SLIM_FIELDS = "key,project"
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
_DEFAULT_ATTACHMENT_SIZE_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
|
||||
class JiraCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint that tracks which slice of the current JQL result set was emitted."""
|
||||
|
||||
start_at: int = 0
|
||||
cursor: str | None = None
|
||||
ids_done: bool = False
|
||||
all_issue_ids: list[list[str]] = Field(default_factory=list)
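# Editor's sketch (hedged): JiraCheckpoint is a pydantic model, so it can be
# serialized between runs and restored via validate_checkpoint_json defined below.
_checkpoint = JiraCheckpoint(has_more=True, start_at=0)
_restored = JiraCheckpoint.model_validate_json(_checkpoint.model_dump_json())
assert _restored.start_at == 0 and _restored.ids_done is False and _restored.cursor is None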
|
||||
|
||||
|
||||
_TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""Retrieve Jira issues and emit them as markdown documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
project_key: str | None = None,
|
||||
jql_query: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
include_comments: bool = True,
|
||||
include_attachments: bool = False,
|
||||
labels_to_skip: Sequence[str] | None = None,
|
||||
comment_email_blacklist: Sequence[str] | None = None,
|
||||
scoped_token: bool = False,
|
||||
attachment_size_limit: int | None = None,
|
||||
timezone_offset: float | None = None,
|
||||
) -> None:
|
||||
if not jira_base_url:
|
||||
raise ConnectorValidationError("Jira base URL must be provided.")
|
||||
|
||||
self.jira_base_url = jira_base_url.rstrip("/")
|
||||
self.project_key = project_key
|
||||
self.jql_query = jql_query
|
||||
self.batch_size = batch_size
|
||||
self.include_comments = include_comments
|
||||
self.include_attachments = include_attachments
|
||||
configured_labels = labels_to_skip or JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
self.labels_to_skip = {label.lower() for label in configured_labels}
|
||||
self.comment_email_blacklist = {email.lower() for email in comment_email_blacklist or []}
|
||||
self.scoped_token = scoped_token
|
||||
self.jira_client: JIRA | None = None
|
||||
|
||||
self.max_ticket_size = JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
self.attachment_size_limit = attachment_size_limit if attachment_size_limit and attachment_size_limit > 0 else _DEFAULT_ATTACHMENT_SIZE_LIMIT
|
||||
self._fields_param = _DEFAULT_FIELDS
|
||||
self._slim_fields = _SLIM_FIELDS
|
||||
|
||||
tz_offset_value = float(timezone_offset) if timezone_offset is not None else float(JIRA_TIMEZONE_OFFSET)
|
||||
self.timezone_offset = tz_offset_value
|
||||
self.timezone = timezone(offset=timedelta(hours=tz_offset_value))
|
||||
self._timezone_overridden = timezone_offset is not None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Connector lifecycle helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Instantiate the Jira client using either an API token or username/password."""
|
||||
jira_url_for_client = self.jira_base_url
|
||||
if self.scoped_token:
|
||||
if is_atlassian_cloud_url(self.jira_base_url):
|
||||
try:
|
||||
jira_url_for_client = scoped_url(self.jira_base_url, "jira")
|
||||
except ValueError as exc:
|
||||
raise ConnectorValidationError(str(exc)) from exc
|
||||
else:
|
||||
logger.warning(f"[Jira] Scoped token requested but Jira base URL {self.jira_base_url} does not appear to be an Atlassian Cloud domain; scoped token ignored.")
|
||||
|
||||
user_email = credentials.get("jira_user_email") or credentials.get("username")
|
||||
api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token")
|
||||
password = credentials.get("jira_password") or credentials.get("password")
|
||||
rest_api_version = credentials.get("rest_api_version")
|
||||
|
||||
if not rest_api_version:
|
||||
rest_api_version = JIRA_CLOUD_API_VERSION if api_token else JIRA_SERVER_API_VERSION
|
||||
options: dict[str, Any] = {"rest_api_version": rest_api_version}
|
||||
|
||||
try:
|
||||
if user_email and api_token:
|
||||
self.jira_client = JIRA(
|
||||
server=jira_url_for_client,
|
||||
basic_auth=(user_email, api_token),
|
||||
options=options,
|
||||
)
|
||||
elif api_token:
|
||||
self.jira_client = JIRA(
|
||||
server=jira_url_for_client,
|
||||
token_auth=api_token,
|
||||
options=options,
|
||||
)
|
||||
elif user_email and password:
|
||||
self.jira_client = JIRA(
|
||||
server=jira_url_for_client,
|
||||
basic_auth=(user_email, password),
|
||||
options=options,
|
||||
)
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Jira credentials must include either an API token or username/password.")
|
||||
except Exception as exc: # pragma: no cover - jira lib raises many types
|
||||
raise ConnectorMissingCredentialError(f"Jira: {exc}") from exc
|
||||
self._sync_timezone_from_server()
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connectivity by fetching basic Jira info."""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
try:
|
||||
if self.jql_query:
|
||||
dummy_checkpoint = self.build_dummy_checkpoint()
|
||||
checkpoint_callback = self._make_checkpoint_callback(dummy_checkpoint)
|
||||
iterator = self._perform_jql_search(
|
||||
jql=self.jql_query,
|
||||
start=0,
|
||||
max_results=1,
|
||||
fields="key",
|
||||
all_issue_ids=dummy_checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=dummy_checkpoint.cursor,
|
||||
ids_done=dummy_checkpoint.ids_done,
|
||||
)
|
||||
next(iter(iterator), None)
|
||||
elif self.project_key:
|
||||
self.jira_client.project(self.project_key)
|
||||
else:
|
||||
self.jira_client.projects()
|
||||
except Exception as exc: # pragma: no cover - dependent on Jira responses
|
||||
self._handle_validation_error(exc)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Checkpointed connector implementation
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
"""Load Jira issues, emitting a Document per issue."""
|
||||
try:
|
||||
return (yield from self._load_with_retry(start, end, checkpoint))
|
||||
except Exception as exc:
|
||||
logger.exception(f"[Jira] Jira query ultimately failed: {exc}")
|
||||
yield ConnectorFailure(
|
||||
failure_message=f"Failed to query Jira: {exc}",
|
||||
exception=exc,
|
||||
)
|
||||
return JiraCheckpoint(has_more=False, start_at=checkpoint.start_at)
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
"""Permissions are not synced separately, so reuse the standard loader."""
|
||||
return (yield from self.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
|
||||
|
||||
def _load_with_retry(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
attempt_start = start
|
||||
retried_with_buffer = False
|
||||
attempt = 0
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
jql = self._build_jql(attempt_start, end)
|
||||
logger.info(f"[Jira] Executing Jira JQL attempt {attempt} (start={attempt_start}, end={end}, buffered_retry={retried_with_buffer}): {jql}")
|
||||
try:
|
||||
return (yield from self._load_from_checkpoint_internal(jql, checkpoint, start_filter=start))
|
||||
except Exception as exc:
|
||||
if attempt_start is not None and not retried_with_buffer and is_atlassian_date_error(exc):
|
||||
attempt_start = attempt_start - ONE_HOUR
|
||||
retried_with_buffer = True
|
||||
logger.info(f"[Jira] Atlassian date error detected; retrying with start={attempt_start}.")
|
||||
continue
|
||||
raise
|
||||
|
||||
def _handle_validation_error(self, exc: Exception) -> None:
|
||||
status_code = getattr(exc, "status_code", None)
|
||||
if status_code == 401:
|
||||
raise InsufficientPermissionsError("Jira credential appears to be invalid or expired (HTTP 401).") from exc
|
||||
if status_code == 403:
|
||||
raise InsufficientPermissionsError("Jira token does not have permission to access the requested resources (HTTP 403).") from exc
|
||||
if status_code == 404:
|
||||
raise ConnectorValidationError("Jira resource not found (HTTP 404).") from exc
|
||||
if status_code == 429:
|
||||
raise ConnectorValidationError("Jira rate limit exceeded during validation (HTTP 429).") from exc
|
||||
|
||||
message = getattr(exc, "text", str(exc))
|
||||
if not message:
|
||||
raise UnexpectedValidationError("Unexpected Jira validation error.") from exc
|
||||
|
||||
raise ConnectorValidationError(f"Jira validation failed: {message}") from exc
|
||||
|
||||
def _load_from_checkpoint_internal(
|
||||
self,
|
||||
jql: str,
|
||||
checkpoint: JiraCheckpoint,
|
||||
start_filter: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
assert self.jira_client, "load_credentials must be called before loading issues."
|
||||
|
||||
page_size = self._full_page_size()
|
||||
new_checkpoint = copy.deepcopy(checkpoint)
|
||||
starting_offset = new_checkpoint.start_at or 0
|
||||
current_offset = starting_offset
|
||||
checkpoint_callback = self._make_checkpoint_callback(new_checkpoint)
|
||||
|
||||
issue_iter = self._perform_jql_search(
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=page_size,
|
||||
fields=self._fields_param,
|
||||
all_issue_ids=new_checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=new_checkpoint.cursor,
|
||||
ids_done=new_checkpoint.ids_done,
|
||||
)
|
||||
|
||||
start_cutoff = float(start_filter) if start_filter is not None else None
|
||||
|
||||
for issue in issue_iter:
|
||||
current_offset += 1
|
||||
issue_key = getattr(issue, "key", "unknown")
|
||||
if should_skip_issue(issue, self.labels_to_skip):
|
||||
continue
|
||||
|
||||
issue_updated = parse_jira_datetime(issue.raw.get("fields", {}).get("updated"))
|
||||
if start_cutoff is not None and issue_updated is not None and issue_updated.timestamp() <= start_cutoff:
|
||||
# Jira JQL only supports minute precision, so we discard already-processed
|
||||
# issues here based on the original second-level cutoff.
|
||||
continue
|
||||
|
||||
try:
|
||||
document = self._issue_to_document(issue)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.exception(f"[Jira] Failed to convert Jira issue {issue_key}: {exc}")
|
||||
yield ConnectorFailure(
|
||||
failure_message=f"Failed to convert Jira issue {issue_key}: {exc}",
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_issue_url(self.jira_base_url, issue_key),
|
||||
),
|
||||
exception=exc,
|
||||
)
|
||||
continue
|
||||
|
||||
if document is not None:
|
||||
yield document
|
||||
if self.include_attachments:
|
||||
for attachment_document in self._attachment_documents(issue):
|
||||
if attachment_document is not None:
|
||||
yield attachment_document
|
||||
|
||||
self._update_checkpoint_for_next_run(
|
||||
checkpoint=new_checkpoint,
|
||||
current_offset=current_offset,
|
||||
starting_offset=starting_offset,
|
||||
page_size=page_size,
|
||||
)
|
||||
new_checkpoint.start_at = current_offset
|
||||
return new_checkpoint
|
||||
|
||||
def build_dummy_checkpoint(self) -> JiraCheckpoint:
|
||||
"""Create an empty checkpoint used to kick off ingestion."""
|
||||
return JiraCheckpoint(has_more=True, start_at=0)
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraCheckpoint:
|
||||
"""Validate a serialized checkpoint."""
|
||||
return JiraCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Slim connector implementation
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None, # noqa: ARG002 - maintained for interface compatibility
|
||||
) -> Generator[list[SlimDocument], None, None]:
|
||||
"""Return lightweight references to Jira issues (used for permission syncing)."""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
start_ts = start if start is not None else 0
|
||||
end_ts = end if end is not None else datetime.now(timezone.utc).timestamp()
|
||||
jql = self._build_jql(start_ts, end_ts)
|
||||
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
checkpoint_callback = self._make_checkpoint_callback(checkpoint)
|
||||
prev_offset = 0
|
||||
current_offset = 0
|
||||
slim_batch: list[SlimDocument] = []
|
||||
|
||||
while checkpoint.has_more:
|
||||
for issue in self._perform_jql_search(
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
fields=self._slim_fields,
|
||||
all_issue_ids=checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=checkpoint.cursor,
|
||||
ids_done=checkpoint.ids_done,
|
||||
):
|
||||
current_offset += 1
|
||||
if should_skip_issue(issue, self.labels_to_skip):
|
||||
continue
|
||||
|
||||
doc_id = build_issue_url(self.jira_base_url, issue.key)
|
||||
slim_batch.append(SlimDocument(id=doc_id))
|
||||
|
||||
if len(slim_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_batch
|
||||
slim_batch = []
|
||||
|
||||
self._update_checkpoint_for_next_run(
|
||||
checkpoint=checkpoint,
|
||||
current_offset=current_offset,
|
||||
starting_offset=prev_offset,
|
||||
page_size=_JIRA_SLIM_PAGE_SIZE,
|
||||
)
|
||||
prev_offset = current_offset
|
||||
|
||||
if slim_batch:
|
||||
yield slim_batch
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _build_jql(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> str:
|
||||
clauses: list[str] = []
|
||||
if self.jql_query:
|
||||
clauses.append(f"({self.jql_query})")
|
||||
elif self.project_key:
|
||||
clauses.append(f'project = "{self.project_key}"')
|
||||
else:
|
||||
raise ConnectorValidationError("Either project_key or jql_query must be provided for Jira connector.")
|
||||
|
||||
if self.labels_to_skip:
|
||||
labels = ", ".join(f'"{label}"' for label in self.labels_to_skip)
|
||||
clauses.append(f"labels NOT IN ({labels})")
|
||||
|
||||
if start is not None:
|
||||
clauses.append(f'updated >= "{self._format_jql_time(start)}"')
|
||||
if end is not None:
|
||||
clauses.append(f'updated <= "{self._format_jql_time(end)}"')
|
||||
|
||||
if not clauses:
|
||||
raise ConnectorValidationError("Unable to build Jira JQL query.")
|
||||
|
||||
jql = " AND ".join(clauses)
|
||||
if "order by" not in jql.lower():
|
||||
jql = f"{jql} ORDER BY updated ASC"
|
||||
return jql
|
||||
|
||||
def _format_jql_time(self, timestamp: SecondsSinceUnixEpoch) -> str:
|
||||
dt_utc = datetime.fromtimestamp(float(timestamp), tz=timezone.utc)
|
||||
dt_local = dt_utc.astimezone(self.timezone)
|
||||
# Jira only accepts minute-precision timestamps in JQL, so we format accordingly
|
||||
# and rely on a post-query second-level filter to avoid duplicates.
|
||||
return dt_local.strftime("%Y-%m-%d %H:%M")
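# Editor's illustration (hedged): the kind of JQL _build_jql above produces for a
# connector configured with project_key="RAG" and no labels_to_skip, assuming a
# UTC offset; the exact timestamps depend on start/end and the resolved timezone:
#
#   project = "RAG" AND updated >= "2024-05-01 08:00" AND updated <= "2024-05-02 08:00" ORDER BY updated ASC
#
# Timestamps carry only minute precision, which is why issues are re-checked
# against the original second-level cutoff after the query returns.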
|
||||
|
||||
def _issue_to_document(self, issue: Issue) -> Document | None:
|
||||
fields = issue.raw.get("fields", {})
|
||||
summary = fields.get("summary") or ""
|
||||
description_text = extract_body_text(fields.get("description"))
|
||||
comments_text = (
|
||||
format_comments(
|
||||
fields.get("comment"),
|
||||
blacklist=self.comment_email_blacklist,
|
||||
)
|
||||
if self.include_comments
|
||||
else ""
|
||||
)
|
||||
attachments_text = format_attachments(fields.get("attachment"))
|
||||
|
||||
reporter_name, reporter_email = extract_user(fields.get("reporter"))
|
||||
assignee_name, assignee_email = extract_user(fields.get("assignee"))
|
||||
status = extract_named_value(fields.get("status"))
|
||||
priority = extract_named_value(fields.get("priority"))
|
||||
issue_type = extract_named_value(fields.get("issuetype"))
|
||||
project = fields.get("project") or {}
|
||||
|
||||
issue_url = build_issue_url(self.jira_base_url, issue.key)
|
||||
|
||||
metadata_lines = [
|
||||
f"key: {issue.key}",
|
||||
f"url: {issue_url}",
|
||||
f"summary: {summary}",
|
||||
f"status: {status or 'Unknown'}",
|
||||
f"priority: {priority or 'Unspecified'}",
|
||||
f"issue_type: {issue_type or 'Unknown'}",
|
||||
f"project: {project.get('name') or ''}",
|
||||
f"project_key: {project.get('key') or self.project_key or ''}",
|
||||
]
|
||||
|
||||
if reporter_name:
|
||||
metadata_lines.append(f"reporter: {reporter_name}")
|
||||
if reporter_email:
|
||||
metadata_lines.append(f"reporter_email: {reporter_email}")
|
||||
if assignee_name:
|
||||
metadata_lines.append(f"assignee: {assignee_name}")
|
||||
if assignee_email:
|
||||
metadata_lines.append(f"assignee_email: {assignee_email}")
|
||||
if fields.get("labels"):
|
||||
metadata_lines.append(f"labels: {', '.join(fields.get('labels'))}")
|
||||
|
||||
created_dt = parse_jira_datetime(fields.get("created"))
|
||||
updated_dt = parse_jira_datetime(fields.get("updated")) or created_dt or datetime.now(timezone.utc)
|
||||
metadata_lines.append(f"created: {created_dt.isoformat() if created_dt else ''}")
|
||||
metadata_lines.append(f"updated: {updated_dt.isoformat() if updated_dt else ''}")
|
||||
|
||||
sections: list[str] = [
|
||||
"---",
|
||||
"\n".join(filter(None, metadata_lines)),
|
||||
"---",
|
||||
"",
|
||||
"## Description",
|
||||
description_text or "No description provided.",
|
||||
]
|
||||
|
||||
if comments_text:
|
||||
sections.extend(["", "## Comments", comments_text])
|
||||
if attachments_text:
|
||||
sections.extend(["", "## Attachments", attachments_text])
|
||||
|
||||
blob_text = "\n".join(sections).strip() + "\n"
|
||||
blob = blob_text.encode("utf-8")
|
||||
|
||||
if len(blob) > self.max_ticket_size:
|
||||
logger.info(f"[Jira] Skipping {issue.key} because it exceeds the maximum size of {self.max_ticket_size} bytes.")
|
||||
return None
|
||||
|
||||
semantic_identifier = f"{issue.key}: {summary}" if summary else issue.key
|
||||
|
||||
return Document(
|
||||
id=issue_url,
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=".md",
|
||||
blob=blob,
|
||||
doc_updated_at=updated_dt,
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
|
||||
def _attachment_documents(self, issue: Issue) -> Iterable[Document]:
|
||||
attachments = issue.raw.get("fields", {}).get("attachment") or []
|
||||
for attachment in attachments:
|
||||
try:
|
||||
document = self._attachment_to_document(issue, attachment)
|
||||
if document is not None:
|
||||
yield document
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
failed_id = attachment.get("id") or attachment.get("filename")
|
||||
issue_key = getattr(issue, "key", "unknown")
|
||||
logger.warning(f"[Jira] Failed to process attachment {failed_id} for issue {issue_key}: {exc}")
|
||||
|
||||
def _attachment_to_document(self, issue: Issue, attachment: dict[str, Any]) -> Document | None:
|
||||
if not self.include_attachments:
|
||||
return None
|
||||
|
||||
filename = attachment.get("filename")
|
||||
content_url = attachment.get("content")
|
||||
if not filename or not content_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
attachment_size = int(attachment.get("size", 0))
|
||||
except (TypeError, ValueError):
|
||||
attachment_size = 0
|
||||
if attachment_size and attachment_size > self.attachment_size_limit:
|
||||
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because reported size exceeds limit ({self.attachment_size_limit} bytes).")
|
||||
return None
|
||||
|
||||
blob = self._download_attachment(content_url)
|
||||
if blob is None:
|
||||
return None
|
||||
|
||||
if len(blob) > self.attachment_size_limit:
|
||||
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because it exceeds the size limit ({self.attachment_size_limit} bytes).")
|
||||
return None
|
||||
|
||||
attachment_time = parse_jira_datetime(attachment.get("created")) or parse_jira_datetime(attachment.get("updated"))
|
||||
updated_dt = attachment_time or parse_jira_datetime(issue.raw.get("fields", {}).get("updated")) or datetime.now(timezone.utc)
|
||||
|
||||
extension = os.path.splitext(filename)[1] or ""
|
||||
document_id = f"{issue.key}::attachment::{attachment.get('id') or filename}"
|
||||
semantic_identifier = f"{issue.key} attachment: {filename}"
|
||||
|
||||
return Document(
|
||||
id=document_id,
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=extension,
|
||||
blob=blob,
|
||||
doc_updated_at=updated_dt,
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
|
||||
def _download_attachment(self, url: str) -> bytes | None:
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
response = self.jira_client._session.get(url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
def _sync_timezone_from_server(self) -> None:
|
||||
if self._timezone_overridden or not self.jira_client:
|
||||
return
|
||||
try:
|
||||
server_info = self.jira_client.server_info()
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.info(f"[Jira] Unable to determine timezone from server info; continuing with offset {self.timezone_offset}. Error: {exc}")
|
||||
return
|
||||
|
||||
detected_offset = self._extract_timezone_offset(server_info)
|
||||
if detected_offset is None or detected_offset == self.timezone_offset:
|
||||
return
|
||||
|
||||
self.timezone_offset = detected_offset
|
||||
self.timezone = timezone(offset=timedelta(hours=detected_offset))
|
||||
logger.info(f"[Jira] Timezone offset adjusted to {detected_offset} hours using Jira server info.")
|
||||
|
||||
def _extract_timezone_offset(self, server_info: dict[str, Any]) -> float | None:
|
||||
server_time_raw = server_info.get("serverTime")
|
||||
if isinstance(server_time_raw, str):
|
||||
offset = self._parse_offset_from_datetime_string(server_time_raw)
|
||||
if offset is not None:
|
||||
return offset
|
||||
|
||||
tz_name = server_info.get("timeZone")
|
||||
if isinstance(tz_name, str):
|
||||
offset = self._offset_from_zone_name(tz_name)
|
||||
if offset is not None:
|
||||
return offset
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_offset_from_datetime_string(value: str) -> float | None:
|
||||
normalized = JiraConnector._normalize_datetime_string(value)
|
||||
try:
|
||||
dt = datetime.fromisoformat(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if dt.tzinfo is None:
|
||||
return 0.0
|
||||
|
||||
offset = dt.tzinfo.utcoffset(dt)
|
||||
if offset is None:
|
||||
return None
|
||||
return offset.total_seconds() / 3600.0
|
||||
|
||||
@staticmethod
|
||||
def _normalize_datetime_string(value: str) -> str:
|
||||
trimmed = (value or "").strip()
|
||||
if trimmed.endswith("Z"):
|
||||
return f"{trimmed[:-1]}+00:00"
|
||||
|
||||
match = _TZ_OFFSET_PATTERN.search(trimmed)
|
||||
if match and match.group(3) != ":":
|
||||
sign, hours, _, minutes = match.groups()
|
||||
trimmed = f"{trimmed[: match.start()]}{sign}{hours}:{minutes}"
|
||||
return trimmed
|
||||
|
||||
@staticmethod
|
||||
def _offset_from_zone_name(name: str) -> float | None:
|
||||
try:
|
||||
tz = ZoneInfo(name)
|
||||
except (ZoneInfoNotFoundError, ValueError):
|
||||
return None
|
||||
reference = datetime.now(tz)
|
||||
offset = reference.utcoffset()
|
||||
if offset is None:
|
||||
return None
|
||||
return offset.total_seconds() / 3600.0
|
||||
|
||||
def _is_cloud_client(self) -> bool:
|
||||
if not self.jira_client:
|
||||
return False
|
||||
rest_version = str(self.jira_client._options.get("rest_api_version", "")).strip()
|
||||
return rest_version == str(JIRA_CLOUD_API_VERSION)
|
||||
|
||||
def _full_page_size(self) -> int:
|
||||
return max(1, min(self.batch_size, _JIRA_FULL_PAGE_SIZE))
|
||||
|
||||
def _perform_jql_search(
|
||||
self,
|
||||
*,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
all_issue_ids: list[list[str]] | None = None,
|
||||
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
|
||||
next_page_token: str | None = None,
|
||||
ids_done: bool = False,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
is_cloud = self._is_cloud_client()
|
||||
if is_cloud:
|
||||
if all_issue_ids is None:
|
||||
raise ValueError("all_issue_ids is required for Jira Cloud searches.")
|
||||
yield from self._perform_jql_search_v3(
|
||||
jql=jql,
|
||||
max_results=max_results,
|
||||
fields=fields,
|
||||
all_issue_ids=all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=next_page_token,
|
||||
ids_done=ids_done,
|
||||
)
|
||||
else:
|
||||
yield from self._perform_jql_search_v2(
|
||||
jql=jql,
|
||||
start=start,
|
||||
max_results=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
def _perform_jql_search_v3(
|
||||
self,
|
||||
*,
|
||||
jql: str,
|
||||
max_results: int,
|
||||
all_issue_ids: list[list[str]],
|
||||
fields: str | None = None,
|
||||
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
|
||||
next_page_token: str | None = None,
|
||||
ids_done: bool = False,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
|
||||
if not ids_done:
|
||||
new_ids, page_token = self._enhanced_search_ids(jql, next_page_token)
|
||||
if checkpoint_callback is not None and new_ids:
|
||||
checkpoint_callback(
|
||||
self._chunk_issue_ids(new_ids, max_results),
|
||||
page_token,
|
||||
)
|
||||
elif checkpoint_callback is not None:
|
||||
checkpoint_callback([], page_token)
|
||||
|
||||
if all_issue_ids:
|
||||
issue_ids = all_issue_ids.pop()
|
||||
if issue_ids:
|
||||
yield from self._bulk_fetch_issues(issue_ids, fields)
|
||||
|
||||
def _perform_jql_search_v2(
|
||||
self,
|
||||
*,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
|
||||
issues = self.jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields or self._fields_param,
|
||||
expand="renderedFields",
|
||||
)
|
||||
for issue in issues:
|
||||
yield issue
|
||||
|
||||
def _enhanced_search_ids(
|
||||
self,
|
||||
jql: str,
|
||||
next_page_token: str | None,
|
||||
) -> tuple[list[str], str | None]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
enhanced_search_path = self.jira_client._get_url("search/jql")
|
||||
params: dict[str, str | int | None] = {
|
||||
"jql": jql,
|
||||
"maxResults": _MAX_RESULTS_FETCH_IDS,
|
||||
"nextPageToken": next_page_token,
|
||||
"fields": "id",
|
||||
}
|
||||
response = self.jira_client._session.get(enhanced_search_path, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [str(issue["id"]) for issue in data.get("issues", [])], data.get("nextPageToken")
|
||||
|
||||
def _bulk_fetch_issues(
|
||||
self,
|
||||
issue_ids: list[str],
|
||||
fields: str | None,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
if not issue_ids:
|
||||
return []
|
||||
|
||||
bulk_fetch_path = self.jira_client._get_url("issue/bulkfetch")
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
response = self.jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [Issue(self.jira_client._options, self.jira_client._session, raw=issue) for issue in data.get("issues", [])]
|
||||
|
||||
@staticmethod
|
||||
def _chunk_issue_ids(issue_ids: list[str], chunk_size: int) -> Iterable[list[str]]:
|
||||
if chunk_size <= 0:
|
||||
chunk_size = _JIRA_FULL_PAGE_SIZE
|
||||
|
||||
for idx in range(0, len(issue_ids), chunk_size):
|
||||
yield issue_ids[idx : idx + chunk_size]
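# Editor's example (hedged): chunking behavior of the helper above, e.g.
#   list(JiraConnector._chunk_issue_ids(["1", "2", "3"], 2)) == [["1", "2"], ["3"]]
# and a non-positive chunk_size falls back to _JIRA_FULL_PAGE_SIZE (50).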
|
||||
|
||||
def _make_checkpoint_callback(self, checkpoint: JiraCheckpoint) -> Callable[[Iterable[list[str]], str | None], None]:
|
||||
def checkpoint_callback(
|
||||
issue_ids: Iterable[list[str]] | list[list[str]],
|
||||
page_token: str | None,
|
||||
) -> None:
|
||||
for id_batch in issue_ids:
|
||||
checkpoint.all_issue_ids.append(list(id_batch))
|
||||
checkpoint.cursor = page_token
|
||||
checkpoint.ids_done = page_token is None
|
||||
|
||||
return checkpoint_callback
|
||||
|
||||
def _update_checkpoint_for_next_run(
|
||||
self,
|
||||
*,
|
||||
checkpoint: JiraCheckpoint,
|
||||
current_offset: int,
|
||||
starting_offset: int,
|
||||
page_size: int,
|
||||
) -> None:
|
||||
if self._is_cloud_client():
|
||||
checkpoint.has_more = bool(checkpoint.all_issue_ids) or not checkpoint.ids_done
|
||||
else:
|
||||
checkpoint.has_more = current_offset - starting_offset == page_size
|
||||
checkpoint.cursor = None
|
||||
checkpoint.ids_done = True
|
||||
checkpoint.all_issue_ids = []
|
||||
|
||||
|
||||
def iterate_jira_documents(
|
||||
connector: "JiraConnector",
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
iteration_limit: int = 100_000,
|
||||
) -> Iterator[Document]:
|
||||
"""Yield documents without materializing the entire result set."""
|
||||
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
iterations = 0
|
||||
|
||||
while checkpoint.has_more:
|
||||
wrapper = CheckpointOutputWrapper[JiraCheckpoint]()
|
||||
generator = wrapper(connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
|
||||
|
||||
for document, failure, next_checkpoint in generator:
|
||||
if failure is not None:
|
||||
failure_message = getattr(failure, "failure_message", str(failure))
|
||||
raise RuntimeError(f"Failed to load Jira documents: {failure_message}")
|
||||
if document is not None:
|
||||
yield document
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
iterations += 1
|
||||
if iterations > iteration_limit:
|
||||
raise RuntimeError("Too many iterations while loading Jira documents.")
|
||||
|
||||
|
||||
def test_jira(
|
||||
*,
|
||||
base_url: str,
|
||||
project_key: str | None = None,
|
||||
jql_query: str | None = None,
|
||||
credentials: dict[str, Any],
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
start_ts: float | None = None,
|
||||
end_ts: float | None = None,
|
||||
connector_options: dict[str, Any] | None = None,
|
||||
) -> list[Document]:
|
||||
"""Programmatic entry point that mirrors the CLI workflow."""
|
||||
|
||||
connector_kwargs = connector_options.copy() if connector_options else {}
|
||||
connector = JiraConnector(
|
||||
jira_base_url=base_url,
|
||||
project_key=project_key,
|
||||
jql_query=jql_query,
|
||||
batch_size=batch_size,
|
||||
**connector_kwargs,
|
||||
)
|
||||
connector.load_credentials(credentials)
|
||||
connector.validate_connector_settings()
|
||||
|
||||
now_ts = datetime.now(timezone.utc).timestamp()
|
||||
start = start_ts if start_ts is not None else 0.0
|
||||
end = end_ts if end_ts is not None else now_ts
|
||||
|
||||
documents = list(iterate_jira_documents(connector, start=start, end=end))
|
||||
logger.info(f"[Jira] Fetched {len(documents)} Jira documents.")
|
||||
for doc in documents[:5]:
|
||||
logger.info(f"[Jira] Document {doc.semantic_identifier} ({doc.id}) size={doc.size_bytes} bytes")
|
||||
return documents
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="Fetch Jira issues and print summary statistics.")
|
||||
parser.add_argument("--base-url", dest="base_url", default=os.environ.get("JIRA_BASE_URL"))
|
||||
parser.add_argument("--project", dest="project_key", default=os.environ.get("JIRA_PROJECT_KEY"))
|
||||
parser.add_argument("--jql", dest="jql_query", default=os.environ.get("JIRA_JQL"))
|
||||
parser.add_argument("--email", dest="user_email", default=os.environ.get("JIRA_USER_EMAIL"))
|
||||
parser.add_argument("--token", dest="api_token", default=os.environ.get("JIRA_API_TOKEN"))
|
||||
parser.add_argument("--password", dest="password", default=os.environ.get("JIRA_PASSWORD"))
|
||||
parser.add_argument("--batch-size", dest="batch_size", type=int, default=int(os.environ.get("JIRA_BATCH_SIZE", INDEX_BATCH_SIZE)))
|
||||
parser.add_argument("--include_comments", dest="include_comments", type=bool, default=True)
|
||||
parser.add_argument("--include_attachments", dest="include_attachments", type=bool, default=True)
|
||||
parser.add_argument("--attachment_size_limit", dest="attachment_size_limit", type=float, default=_DEFAULT_ATTACHMENT_SIZE_LIMIT)
|
||||
parser.add_argument("--start-ts", dest="start_ts", type=float, default=None, help="Epoch seconds inclusive lower bound for updated issues.")
|
||||
parser.add_argument("--end-ts", dest="end_ts", type=float, default=9999999999, help="Epoch seconds inclusive upper bound for updated issues.")
|
||||
return parser
|
||||
|
||||
|
||||
def main(config: dict[str, Any] | None = None) -> None:
|
||||
if config is None:
|
||||
args = _build_arg_parser().parse_args()
|
||||
config = {
|
||||
"base_url": args.base_url,
|
||||
"project_key": args.project_key,
|
||||
"jql_query": args.jql_query,
|
||||
"batch_size": args.batch_size,
|
||||
"start_ts": args.start_ts,
|
||||
"end_ts": args.end_ts,
|
||||
"include_comments": args.include_comments,
|
||||
"include_attachments": args.include_attachments,
|
||||
"attachment_size_limit": args.attachment_size_limit,
|
||||
"credentials": {
|
||||
"jira_user_email": args.user_email,
|
||||
"jira_api_token": args.api_token,
|
||||
"jira_password": args.password,
|
||||
},
|
||||
}
|
||||
|
||||
base_url = config.get("base_url")
|
||||
credentials = config.get("credentials", {})
|
||||
|
||||
print(f"[Jira] {config=}", flush=True)
|
||||
print(f"[Jira] {credentials=}", flush=True)
|
||||
|
||||
if not base_url:
|
||||
raise RuntimeError("Jira base URL must be provided via config or CLI arguments.")
|
||||
if not (credentials.get("jira_api_token") or (credentials.get("jira_user_email") and credentials.get("jira_password"))):
|
||||
raise RuntimeError("Provide either an API token or both email/password for Jira authentication.")
|
||||
|
||||
connector_options = {
|
||||
key: value
|
||||
for key, value in (
|
||||
("include_comments", config.get("include_comments")),
|
||||
("include_attachments", config.get("include_attachments")),
|
||||
("attachment_size_limit", config.get("attachment_size_limit")),
|
||||
("labels_to_skip", config.get("labels_to_skip")),
|
||||
("comment_email_blacklist", config.get("comment_email_blacklist")),
|
||||
("scoped_token", config.get("scoped_token")),
|
||||
("timezone_offset", config.get("timezone_offset")),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
|
||||
documents = test_jira(
|
||||
base_url=base_url,
|
||||
project_key=config.get("project_key"),
|
||||
jql_query=config.get("jql_query"),
|
||||
credentials=credentials,
|
||||
batch_size=config.get("batch_size", INDEX_BATCH_SIZE),
|
||||
start_ts=config.get("start_ts"),
|
||||
end_ts=config.get("end_ts"),
|
||||
connector_options=connector_options,
|
||||
)
|
||||
|
||||
preview_count = min(len(documents), 5)
|
||||
for idx in range(preview_count):
|
||||
doc = documents[idx]
|
||||
print(f"[Jira] [Sample {idx + 1}] {doc.semantic_identifier} | id={doc.id} | size={doc.size_bytes} bytes")
|
||||
|
||||
print(f"[Jira] Jira connector test completed. Documents fetched: {len(documents)}")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover - manual execution path
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
main()
|
||||
149
common/data_source/jira/utils.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""Helper utilities for the Jira connector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Collection
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Iterable
|
||||
|
||||
from jira.resources import Issue
|
||||
|
||||
from common.data_source.utils import datetime_from_string
|
||||
|
||||
JIRA_SERVER_API_VERSION = os.environ.get("JIRA_SERVER_API_VERSION", "2")
|
||||
JIRA_CLOUD_API_VERSION = os.environ.get("JIRA_CLOUD_API_VERSION", "3")
|
||||
|
||||
|
||||
def build_issue_url(base_url: str, issue_key: str) -> str:
|
||||
"""Return the canonical UI URL for a Jira issue."""
|
||||
return f"{base_url.rstrip('/')}/browse/{issue_key}"
|
||||
|
||||
|
||||
def parse_jira_datetime(value: Any) -> datetime | None:
|
||||
"""Best-effort parse of Jira datetime values to aware UTC datetimes."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value.astimezone(timezone.utc) if value.tzinfo else value.replace(tzinfo=timezone.utc)
|
||||
if isinstance(value, str):
|
||||
return datetime_from_string(value)
|
||||
return None
|
||||
|
||||
|
||||
def extract_named_value(value: Any) -> str | None:
|
||||
"""Extract a readable string out of Jira's typed objects."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return value.get("name") or value.get("value")
|
||||
return getattr(value, "name", None)
|
||||
|
||||
|
||||
def extract_user(value: Any) -> tuple[str | None, str | None]:
|
||||
"""Return display name + email tuple for a Jira user blob."""
|
||||
if value is None:
|
||||
return None, None
|
||||
if isinstance(value, dict):
|
||||
return value.get("displayName"), value.get("emailAddress")
|
||||
|
||||
display = getattr(value, "displayName", None)
|
||||
email = getattr(value, "emailAddress", None)
|
||||
return display, email
|
||||
|
||||
|
||||
def extract_text_from_adf(adf: Any) -> str:
|
||||
"""Flatten Atlassian Document Format (ADF) structures to text."""
|
||||
texts: list[str] = []
|
||||
|
||||
def _walk(node: Any) -> None:
|
||||
if node is None:
|
||||
return
|
||||
if isinstance(node, dict):
|
||||
node_type = node.get("type")
|
||||
if node_type == "text":
|
||||
texts.append(node.get("text", ""))
|
||||
for child in node.get("content", []):
|
||||
_walk(child)
|
||||
elif isinstance(node, list):
|
||||
for child in node:
|
||||
_walk(child)
|
||||
|
||||
_walk(adf)
|
||||
return "\n".join(part for part in texts if part)
|
||||
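A small illustrative input/output pair for the ADF flattener (hypothetical document; the expected result follows from joining text nodes with newlines):

adf = {"type": "doc", "content": [
    {"type": "paragraph", "content": [{"type": "text", "text": "Release blocked"}]},
    {"type": "paragraph", "content": [{"type": "text", "text": "See linked incident"}]},
]}
assert extract_text_from_adf(adf) == "Release blocked\nSee linked incident"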
|
||||
|
||||
def extract_body_text(value: Any) -> str:
|
||||
"""Normalize Jira description/comments (raw/adf/str) into plain text."""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
return extract_text_from_adf(value).strip()
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
def format_comments(
|
||||
comment_block: Any,
|
||||
*,
|
||||
blacklist: Collection[str],
|
||||
) -> str:
|
||||
"""Convert Jira comments into a markdown-ish bullet list."""
|
||||
if not isinstance(comment_block, dict):
|
||||
return ""
|
||||
|
||||
comments = comment_block.get("comments") or []
|
||||
lines: list[str] = []
|
||||
normalized_blacklist = {email.lower() for email in blacklist if email}
|
||||
|
||||
for comment in comments:
|
||||
author = comment.get("author") or {}
|
||||
author_email = (author.get("emailAddress") or "").lower()
|
||||
if author_email and author_email in normalized_blacklist:
|
||||
continue
|
||||
|
||||
author_name = author.get("displayName") or author.get("name") or author_email or "Unknown"
|
||||
created = parse_jira_datetime(comment.get("created"))
|
||||
created_str = created.isoformat() if created else "Unknown time"
|
||||
body = extract_body_text(comment.get("body"))
|
||||
if not body:
|
||||
continue
|
||||
|
||||
lines.append(f"- {author_name} ({created_str}):\n{body}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
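An illustrative call to the comment formatter above (hypothetical comment payload; not part of the diff):

comment_block = {"comments": [{
    "author": {"displayName": "Dana", "emailAddress": "dana@example.com"},
    "created": "2024-05-01T10:00:00.000+0000",
    "body": "Fixed in build 42.",
}]}
print(format_comments(comment_block, blacklist={"bot@example.com"}))
# - Dana (2024-05-01T10:00:00+00:00):
# Fixed in build 42.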
|
||||
|
||||
def format_attachments(attachments: Any) -> str:
|
||||
"""List Jira attachments as bullet points."""
|
||||
if not isinstance(attachments, list):
|
||||
return ""
|
||||
|
||||
attachment_lines: list[str] = []
|
||||
for attachment in attachments:
|
||||
filename = attachment.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
size = attachment.get("size")
|
||||
size_text = f" ({size} bytes)" if isinstance(size, int) else ""
|
||||
content_url = attachment.get("content") or ""
|
||||
url_suffix = f" -> {content_url}" if content_url else ""
|
||||
attachment_lines.append(f"- {filename}{size_text}{url_suffix}")
|
||||
|
||||
return "\n".join(attachment_lines)
|
||||
|
||||
|
||||
def should_skip_issue(issue: Issue, labels_to_skip: set[str]) -> bool:
|
||||
"""Return True if the issue contains any label from the skip list."""
|
||||
if not labels_to_skip:
|
||||
return False
|
||||
|
||||
fields = getattr(issue, "raw", {}).get("fields", {})
|
||||
labels: Iterable[str] = fields.get("labels") or []
|
||||
for label in labels:
|
||||
if (label or "").lower() in labels_to_skip:
|
||||
return True
|
||||
return False
|
||||
@ -1,112 +0,0 @@
|
||||
"""Jira connector"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint
|
||||
)
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""Jira connector for accessing Jira issues and projects"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_client: JIRA | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Jira credentials"""
|
||||
try:
|
||||
url = credentials.get("url")
|
||||
username = credentials.get("username")
|
||||
password = credentials.get("password")
|
||||
token = credentials.get("token")
|
||||
|
||||
if not url:
|
||||
raise ConnectorMissingCredentialError("Jira URL is required")
|
||||
|
||||
if token:
|
||||
# API token authentication
|
||||
self.jira_client = JIRA(server=url, token_auth=token)
|
||||
elif username and password:
|
||||
# Basic authentication
|
||||
self.jira_client = JIRA(server=url, basic_auth=(username, password))
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Jira: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Jira connector settings"""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
try:
|
||||
# Test connection by getting server info
|
||||
self.jira_client.server_info()
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e):
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
elif "404" in str(e):
|
||||
raise ConnectorValidationError("Jira instance not found")
|
||||
else:
|
||||
raise UnexpectedValidationError(f"Jira validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Jira for recent issues"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
# Simplified implementation
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
@ -94,6 +94,7 @@ class Document(BaseModel):
|
||||
blob: bytes
|
||||
doc_updated_at: datetime
|
||||
size_bytes: int
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
|
||||
378
common/data_source/moodle_connector.py
Normal file
@ -0,0 +1,378 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from retry import retry
|
||||
from typing import Any, Optional
|
||||
|
||||
from markdownify import markdownify as md
|
||||
from moodle import Moodle as MoodleClient, MoodleException
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
ConnectorValidationError,
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.models import Document
|
||||
from common.data_source.utils import batch_generator, rl_requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MoodleConnector(LoadConnector, PollConnector):
|
||||
"""Moodle LMS connector for accessing course content"""
|
||||
|
||||
def __init__(self, moodle_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.moodle_url = moodle_url.rstrip("/")
|
||||
self.batch_size = batch_size
|
||||
self.moodle_client: Optional[MoodleClient] = None
|
||||
|
||||
def _add_token_to_url(self, file_url: str) -> str:
|
||||
"""Append Moodle token to URL if missing"""
|
||||
if not self.moodle_client:
|
||||
return file_url
|
||||
token = getattr(self.moodle_client, "token", "")
|
||||
if "token=" in file_url.lower():
|
||||
return file_url
|
||||
delimiter = "&" if "?" in file_url else "?"
|
||||
return f"{file_url}{delimiter}token={token}"
|
||||
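A quick illustration of the token-appending behaviour (hypothetical instance; before credentials are loaded the URL is returned unchanged):

connector = MoodleConnector("https://moodle.example.com")
print(connector._add_token_to_url("https://moodle.example.com/file.pdf?forcedownload=1"))
# unchanged here; once a client with a token exists, "&token=<token>" is appended ("?" when there is no query string)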
|
||||
def _log_error(self, context: str, error: Exception, level: str = "warning") -> None:
|
||||
"""Simplified logging wrapper"""
|
||||
msg = f"{context}: {error}"
|
||||
if level == "error":
|
||||
logger.error(msg)
|
||||
else:
|
||||
logger.warning(msg)
|
||||
|
||||
def _get_latest_timestamp(self, *timestamps: int) -> int:
|
||||
"""Return latest valid timestamp"""
|
||||
return max((t for t in timestamps if t and t > 0), default=0)
|
||||
|
||||
def _yield_in_batches(
|
||||
self, generator: Generator[Document, None, None]
|
||||
) -> Generator[list[Document], None, None]:
|
||||
for batch in batch_generator(generator, self.batch_size):
|
||||
yield batch
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
token = credentials.get("moodle_token")
|
||||
if not token:
|
||||
raise ConnectorMissingCredentialError("Moodle API token is required")
|
||||
|
||||
try:
|
||||
self.moodle_client = MoodleClient(
|
||||
self.moodle_url + "/webservice/rest/server.php", token
|
||||
)
|
||||
self.moodle_client.core.webservice.get_site_info()
|
||||
except MoodleException as e:
|
||||
if "invalidtoken" in str(e).lower():
|
||||
raise CredentialExpiredError("Moodle token is invalid or expired")
|
||||
raise ConnectorMissingCredentialError(f"Failed to initialize Moodle client: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if not self.moodle_client:
|
||||
raise ConnectorMissingCredentialError("Moodle client not initialized")
|
||||
|
||||
try:
|
||||
site_info = self.moodle_client.core.webservice.get_site_info()
|
||||
if not site_info.sitename:
|
||||
raise InsufficientPermissionsError("Invalid Moodle API response")
|
||||
except MoodleException as e:
|
||||
msg = str(e).lower()
|
||||
if "invalidtoken" in msg:
|
||||
raise CredentialExpiredError("Moodle token is invalid or expired")
|
||||
if "accessexception" in msg:
|
||||
raise InsufficientPermissionsError(
|
||||
"Insufficient permissions. Ensure web services are enabled and permissions are correct."
|
||||
)
|
||||
raise ConnectorValidationError(f"Moodle validation error: {e}")
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Unexpected validation error: {e}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Data loading & polling
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def load_from_state(self) -> Generator[list[Document], None, None]:
|
||||
if not self.moodle_client:
|
||||
raise ConnectorMissingCredentialError("Moodle client not initialized")
|
||||
|
||||
logger.info("Starting full load from Moodle workspace")
|
||||
courses = self._get_enrolled_courses()
|
||||
if not courses:
|
||||
logger.warning("No courses found to process")
|
||||
return
|
||||
|
||||
yield from self._yield_in_batches(self._process_courses(courses))
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> Generator[list[Document], None, None]:
|
||||
if not self.moodle_client:
|
||||
raise ConnectorMissingCredentialError("Moodle client not initialized")
|
||||
|
||||
logger.info(
|
||||
f"Polling Moodle updates between {datetime.fromtimestamp(start)} and {datetime.fromtimestamp(end)}"
|
||||
)
|
||||
courses = self._get_enrolled_courses()
|
||||
if not courses:
|
||||
logger.warning("No courses found to poll")
|
||||
return
|
||||
|
||||
yield from self._yield_in_batches(self._get_updated_content(courses, start, end))
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_enrolled_courses(self) -> list:
|
||||
if not self.moodle_client:
|
||||
raise ConnectorMissingCredentialError("Moodle client not initialized")
|
||||
|
||||
try:
|
||||
return self.moodle_client.core.course.get_courses()
|
||||
except MoodleException as e:
|
||||
self._log_error("fetching courses", e, "error")
|
||||
raise ConnectorValidationError(f"Failed to fetch courses: {e}")
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_course_contents(self, course_id: int):
|
||||
if not self.moodle_client:
|
||||
raise ConnectorMissingCredentialError("Moodle client not initialized")
|
||||
|
||||
try:
|
||||
return self.moodle_client.core.course.get_contents(courseid=course_id)
|
||||
except MoodleException as e:
|
||||
self._log_error(f"fetching course contents for {course_id}", e)
|
||||
return []
|
||||
|
||||
def _process_courses(self, courses) -> Generator[Document, None, None]:
|
||||
for course in courses:
|
||||
try:
|
||||
contents = self._get_course_contents(course.id)
|
||||
for section in contents:
|
||||
for module in section.modules:
|
||||
doc = self._process_module(course, section, module)
|
||||
if doc:
|
||||
yield doc
|
||||
except Exception as e:
|
||||
self._log_error(f"processing course {course.fullname}", e)
|
||||
|
||||
def _get_updated_content(
|
||||
self, courses, start: float, end: float
|
||||
) -> Generator[Document, None, None]:
|
||||
for course in courses:
|
||||
try:
|
||||
contents = self._get_course_contents(course.id)
|
||||
for section in contents:
|
||||
for module in section.modules:
|
||||
times = [
|
||||
getattr(module, "timecreated", 0),
|
||||
getattr(module, "timemodified", 0),
|
||||
]
|
||||
if hasattr(module, "contents"):
|
||||
times.extend(
|
||||
getattr(c, "timemodified", 0)
|
||||
for c in module.contents
|
||||
if c and getattr(c, "timemodified", 0)
|
||||
)
|
||||
last_mod = self._get_latest_timestamp(*times)
|
||||
if start < last_mod <= end:
|
||||
doc = self._process_module(course, section, module)
|
||||
if doc:
|
||||
yield doc
|
||||
except Exception as e:
|
||||
self._log_error(f"polling course {course.fullname}", e)
|
||||
|
||||
def _process_module(
|
||||
self, course, section, module
|
||||
) -> Optional[Document]:
|
||||
try:
|
||||
mtype = module.modname
|
||||
if mtype in ["label", "url"]:
|
||||
return None
|
||||
if mtype == "resource":
|
||||
return self._process_resource(course, section, module)
|
||||
if mtype == "forum":
|
||||
return self._process_forum(course, section, module)
|
||||
if mtype == "page":
|
||||
return self._process_page(course, section, module)
|
||||
if mtype in ["assign", "quiz"]:
|
||||
return self._process_activity(course, section, module)
|
||||
if mtype == "book":
|
||||
return self._process_book(course, section, module)
|
||||
except Exception as e:
|
||||
self._log_error(f"processing module {getattr(module, 'name', '?')}", e)
|
||||
return None
|
||||
|
||||
def _process_resource(self, course, section, module) -> Optional[Document]:
|
||||
if not getattr(module, "contents", None):
|
||||
return None
|
||||
|
||||
file_info = module.contents[0]
|
||||
if not getattr(file_info, "fileurl", None):
|
||||
return None
|
||||
|
||||
file_name = os.path.basename(file_info.filename)
|
||||
ts = self._get_latest_timestamp(
|
||||
getattr(module, "timecreated", 0),
|
||||
getattr(module, "timemodified", 0),
|
||||
getattr(file_info, "timemodified", 0),
|
||||
)
|
||||
|
||||
try:
|
||||
resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60)
|
||||
resp.raise_for_status()
|
||||
blob = resp.content
|
||||
ext = os.path.splitext(file_name)[1] or ".bin"
|
||||
semantic_id = f"{course.fullname} / {section.name} / {file_name}"
|
||||
return Document(
|
||||
id=f"moodle_resource_{module.id}",
|
||||
source="moodle",
|
||||
semantic_identifier=semantic_id,
|
||||
extension=ext,
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
except Exception as e:
|
||||
self._log_error(f"downloading resource {file_name}", e, "error")
|
||||
return None
|
||||
|
||||
def _process_forum(self, course, section, module) -> Optional[Document]:
|
||||
if not self.moodle_client or not getattr(module, "instance", None):
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.moodle_client.mod.forum.get_forum_discussions(forumid=module.instance)
|
||||
disc_list = getattr(result, "discussions", [])
|
||||
if not disc_list:
|
||||
return None
|
||||
|
||||
markdown = [f"# {module.name}\n"]
|
||||
latest_ts = self._get_latest_timestamp(
|
||||
getattr(module, "timecreated", 0),
|
||||
getattr(module, "timemodified", 0),
|
||||
)
|
||||
|
||||
for d in disc_list:
|
||||
markdown.append(f"## {d.name}\n\n{md(d.message or '')}\n\n---\n")
|
||||
latest_ts = max(latest_ts, getattr(d, "timemodified", 0))
|
||||
|
||||
blob = "\n".join(markdown).encode("utf-8")
|
||||
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
|
||||
return Document(
|
||||
id=f"moodle_forum_{module.id}",
|
||||
source="moodle",
|
||||
semantic_identifier=semantic_id,
|
||||
extension=".md",
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(latest_ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
except Exception as e:
|
||||
self._log_error(f"processing forum {module.name}", e)
|
||||
return None
|
||||
|
||||
def _process_page(self, course, section, module) -> Optional[Document]:
|
||||
if not getattr(module, "contents", None):
|
||||
return None
|
||||
|
||||
file_info = module.contents[0]
|
||||
if not getattr(file_info, "fileurl", None):
|
||||
return None
|
||||
|
||||
file_name = os.path.basename(file_info.filename)
|
||||
ts = self._get_latest_timestamp(
|
||||
getattr(module, "timecreated", 0),
|
||||
getattr(module, "timemodified", 0),
|
||||
getattr(file_info, "timemodified", 0),
|
||||
)
|
||||
|
||||
try:
|
||||
resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60)
|
||||
resp.raise_for_status()
|
||||
blob = resp.content
|
||||
ext = os.path.splitext(file_name)[1] or ".html"
|
||||
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
|
||||
return Document(
|
||||
id=f"moodle_page_{module.id}",
|
||||
source="moodle",
|
||||
semantic_identifier=semantic_id,
|
||||
extension=ext,
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
except Exception as e:
|
||||
self._log_error(f"processing page {file_name}", e, "error")
|
||||
return None
|
||||
|
||||
def _process_activity(self, course, section, module) -> Optional[Document]:
|
||||
desc = getattr(module, "description", "")
|
||||
if not desc:
|
||||
return None
|
||||
|
||||
mtype, mname = module.modname, module.name
|
||||
markdown = f"# {mname}\n\n**Type:** {mtype.capitalize()}\n\n{md(desc)}"
|
||||
ts = self._get_latest_timestamp(
|
||||
getattr(module, "timecreated", 0),
|
||||
getattr(module, "timemodified", 0),
|
||||
getattr(module, "added", 0),
|
||||
)
|
||||
|
||||
semantic_id = f"{course.fullname} / {section.name} / {mname}"
|
||||
blob = markdown.encode("utf-8")
|
||||
return Document(
|
||||
id=f"moodle_{mtype}_{module.id}",
|
||||
source="moodle",
|
||||
semantic_identifier=semantic_id,
|
||||
extension=".md",
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
|
||||
def _process_book(self, course, section, module) -> Optional[Document]:
|
||||
if not getattr(module, "contents", None):
|
||||
return None
|
||||
|
||||
contents = module.contents
|
||||
chapters = [
|
||||
c for c in contents
|
||||
if getattr(c, "fileurl", None) and os.path.basename(c.filename) == "index.html"
|
||||
]
|
||||
if not chapters:
|
||||
return None
|
||||
|
||||
latest_ts = self._get_latest_timestamp(
|
||||
getattr(module, "timecreated", 0),
|
||||
getattr(module, "timemodified", 0),
|
||||
*[getattr(c, "timecreated", 0) for c in contents],
|
||||
*[getattr(c, "timemodified", 0) for c in contents],
|
||||
)
|
||||
|
||||
markdown_parts = [f"# {module.name}\n"]
|
||||
for ch in chapters:
|
||||
try:
|
||||
resp = rl_requests.get(self._add_token_to_url(ch.fileurl), timeout=60)
|
||||
resp.raise_for_status()
|
||||
html = resp.content.decode("utf-8", errors="ignore")
|
||||
markdown_parts.append(md(html) + "\n\n---\n")
|
||||
except Exception as e:
|
||||
self._log_error(f"processing book chapter {ch.filename}", e)
|
||||
|
||||
blob = "\n".join(markdown_parts).encode("utf-8")
|
||||
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
|
||||
return Document(
|
||||
id=f"moodle_book_{module.id}",
|
||||
source="moodle",
|
||||
semantic_identifier=semantic_id,
|
||||
extension=".md",
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(latest_ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
@ -1,38 +1,45 @@
|
||||
import html
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from retry import retry
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
DocumentSource, NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
|
||||
NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
|
||||
DocumentSource,
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError,
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch
|
||||
SecondsSinceUnixEpoch,
|
||||
)
|
||||
from common.data_source.models import (
|
||||
Document,
|
||||
TextSection, GenerateDocumentsOutput
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.models import (
|
||||
NotionPage,
|
||||
GenerateDocumentsOutput,
|
||||
NotionBlock,
|
||||
NotionSearchResponse
|
||||
NotionPage,
|
||||
NotionSearchResponse,
|
||||
TextSection,
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
rl_requests,
|
||||
batch_generator,
|
||||
datetime_from_string,
|
||||
fetch_notion_data,
|
||||
filter_pages_by_time,
|
||||
properties_to_str,
|
||||
filter_pages_by_time, datetime_from_string
|
||||
rl_requests,
|
||||
)
|
||||
|
||||
|
||||
@ -61,11 +68,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_child_blocks(
|
||||
self, block_id: str, cursor: Optional[str] = None
|
||||
) -> dict[str, Any] | None:
|
||||
def _fetch_child_blocks(self, block_id: str, cursor: Optional[str] = None) -> dict[str, Any] | None:
|
||||
"""Fetch all child blocks via the Notion API."""
|
||||
logging.debug(f"Fetching children of block with ID '{block_id}'")
|
||||
logging.debug(f"[Notion]: Fetching children of block with ID {block_id}")
|
||||
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
query_params = {"start_cursor": cursor} if cursor else None
|
||||
|
||||
@ -79,49 +84,42 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
if hasattr(e, 'response') and e.response.status_code == 404:
|
||||
logging.error(
|
||||
f"Unable to access block with ID '{block_id}'. "
|
||||
f"This is likely due to the block not being shared with the integration."
|
||||
)
|
||||
if hasattr(e, "response") and e.response.status_code == 404:
|
||||
logging.error(f"[Notion]: Unable to access block with ID {block_id}. This is likely due to the block not being shared with the integration.")
|
||||
return None
|
||||
else:
|
||||
logging.exception(f"Error fetching blocks: {e}")
|
||||
logging.exception(f"[Notion]: Error fetching blocks: {e}")
|
||||
raise
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_page(self, page_id: str) -> NotionPage:
|
||||
"""Fetch a page from its ID via the Notion API."""
|
||||
logging.debug(f"Fetching page for ID '{page_id}'")
|
||||
logging.debug(f"[Notion]: Fetching page for ID {page_id}")
|
||||
page_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
|
||||
try:
|
||||
data = fetch_notion_data(page_url, self.headers, "GET")
|
||||
return NotionPage(**data)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to fetch page, trying database for ID '{page_id}': {e}")
|
||||
logging.warning(f"[Notion]: Failed to fetch page, trying database for ID {page_id}: {e}")
|
||||
return self._fetch_database_as_page(page_id)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||
"""Attempt to fetch a database as a page."""
|
||||
logging.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||
logging.debug(f"[Notion]: Fetching database for ID {database_id} as a page")
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
|
||||
data = fetch_notion_data(database_url, self.headers, "GET")
|
||||
database_name = data.get("title")
|
||||
database_name = (
|
||||
database_name[0].get("text", {}).get("content") if database_name else None
|
||||
)
|
||||
database_name = database_name[0].get("text", {}).get("content") if database_name else None
|
||||
|
||||
return NotionPage(**data, database_name=database_name)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: Optional[str] = None
|
||||
) -> dict[str, Any]:
|
||||
def _fetch_database(self, database_id: str, cursor: Optional[str] = None) -> dict[str, Any]:
|
||||
"""Fetch a database from its ID via the Notion API."""
|
||||
logging.debug(f"Fetching database for ID '{database_id}'")
|
||||
logging.debug(f"[Notion]: Fetching database for ID {database_id}")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
body = {"start_cursor": cursor} if cursor else None
|
||||
|
||||
@ -129,17 +127,12 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
data = fetch_notion_data(block_url, self.headers, "POST", body)
|
||||
return data
|
||||
except Exception as e:
|
||||
if hasattr(e, 'response') and e.response.status_code in [404, 400]:
|
||||
logging.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared with the integration."
|
||||
)
|
||||
if hasattr(e, "response") and e.response.status_code in [404, 400]:
|
||||
logging.error(f"[Notion]: Unable to access database with ID {database_id}. This is likely due to the database not being shared with the integration.")
|
||||
return {"results": [], "next_cursor": None}
|
||||
raise
|
||||
|
||||
def _read_pages_from_database(
|
||||
self, database_id: str
|
||||
) -> tuple[list[NotionBlock], list[str]]:
|
||||
def _read_pages_from_database(self, database_id: str) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Returns a list of top level blocks and all page IDs in the database."""
|
||||
result_blocks: list[NotionBlock] = []
|
||||
result_pages: list[str] = []
|
||||
@ -158,10 +151,10 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
if obj_type == "page":
|
||||
logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
|
||||
logging.debug(f"[Notion]: Found page with ID {obj_id} in database {database_id}")
|
||||
result_pages.append(result["id"])
|
||||
elif obj_type == "database":
|
||||
logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
|
||||
logging.debug(f"[Notion]: Found database with ID {obj_id} in database {database_id}")
|
||||
_, child_pages = self._read_pages_from_database(obj_id)
|
||||
result_pages.extend(child_pages)
|
||||
|
||||
@ -172,44 +165,229 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
return result_blocks, result_pages
|
||||
|
||||
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Reads all child blocks for the specified block, returns blocks and child page ids."""
|
||||
def _extract_rich_text(self, rich_text_array: list[dict[str, Any]]) -> str:
|
||||
collected_text: list[str] = []
|
||||
for rich_text in rich_text_array:
|
||||
content = ""
|
||||
r_type = rich_text.get("type")
|
||||
|
||||
if r_type == "equation":
|
||||
expr = rich_text.get("equation", {}).get("expression")
|
||||
if expr:
|
||||
content = expr
|
||||
elif r_type == "mention":
|
||||
mention = rich_text.get("mention", {}) or {}
|
||||
mention_type = mention.get("type")
|
||||
mention_value = mention.get(mention_type, {}) if mention_type else {}
|
||||
if mention_type == "date":
|
||||
start = mention_value.get("start")
|
||||
end = mention_value.get("end")
|
||||
if start and end:
|
||||
content = f"{start} - {end}"
|
||||
elif start:
|
||||
content = start
|
||||
elif mention_type in {"page", "database"}:
|
||||
content = mention_value.get("id", rich_text.get("plain_text", ""))
|
||||
elif mention_type == "link_preview":
|
||||
content = mention_value.get("url", rich_text.get("plain_text", ""))
|
||||
else:
|
||||
content = rich_text.get("plain_text", "") or str(mention_value)
|
||||
else:
|
||||
if rich_text.get("plain_text"):
|
||||
content = rich_text["plain_text"]
|
||||
elif "text" in rich_text and rich_text["text"].get("content"):
|
||||
content = rich_text["text"]["content"]
|
||||
|
||||
href = rich_text.get("href")
|
||||
if content and href:
|
||||
content = f"{content} ({href})"
|
||||
|
||||
if content:
|
||||
collected_text.append(content)
|
||||
|
||||
return "".join(collected_text).strip()
|
||||
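An illustrative rich-text array and the string the extractor above would produce (hypothetical data; not part of the diff):

rich_text = [
    {"type": "text", "plain_text": "Spec", "href": "https://example.com/spec"},
    {"type": "equation", "equation": {"expression": "E = mc^2"}},
]
# -> "Spec (https://example.com/spec)E = mc^2"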
|
||||
def _build_table_html(self, table_block_id: str) -> str | None:
|
||||
rows: list[str] = []
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_child_blocks(table_block_id, cursor)
|
||||
if data is None:
|
||||
break
|
||||
|
||||
for result in data["results"]:
|
||||
if result.get("type") != "table_row":
|
||||
continue
|
||||
cells_html: list[str] = []
|
||||
for cell in result["table_row"].get("cells", []):
|
||||
cell_text = self._extract_rich_text(cell)
|
||||
cell_html = html.escape(cell_text) if cell_text else ""
|
||||
cells_html.append(f"<td>{cell_html}</td>")
|
||||
rows.append(f"<tr>{''.join(cells_html)}</tr>")
|
||||
|
||||
if data.get("next_cursor") is None:
|
||||
break
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
return "<table>\n" + "\n".join(rows) + "\n</table>"
|
||||
|
||||
def _download_file(self, url: str) -> bytes | None:
|
||||
try:
|
||||
response = rl_requests.get(url, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except Exception as exc:
|
||||
logging.warning(f"[Notion]: Failed to download Notion file from {url}: {exc}")
|
||||
return None
|
||||
|
||||
def _extract_file_metadata(self, result_obj: dict[str, Any], block_id: str) -> tuple[str | None, str, str | None]:
|
||||
file_source_type = result_obj.get("type")
|
||||
file_source = result_obj.get(file_source_type, {}) if file_source_type else {}
|
||||
url = file_source.get("url")
|
||||
|
||||
name = result_obj.get("name") or file_source.get("name")
|
||||
if url and not name:
|
||||
parsed_name = Path(urlparse(url).path).name
|
||||
name = parsed_name or f"notion_file_{block_id}"
|
||||
elif not name:
|
||||
name = f"notion_file_{block_id}"
|
||||
|
||||
caption = self._extract_rich_text(result_obj.get("caption", [])) if "caption" in result_obj else None
|
||||
|
||||
return url, name, caption
|
||||
|
||||
def _build_attachment_document(
|
||||
self,
|
||||
block_id: str,
|
||||
url: str,
|
||||
name: str,
|
||||
caption: Optional[str],
|
||||
page_last_edited_time: Optional[str],
|
||||
) -> Document | None:
|
||||
file_bytes = self._download_file(url)
|
||||
if file_bytes is None:
|
||||
return None
|
||||
|
||||
extension = Path(name).suffix or Path(urlparse(url).path).suffix or ".bin"
|
||||
if extension and not extension.startswith("."):
|
||||
extension = f".{extension}"
|
||||
if not extension:
|
||||
extension = ".bin"
|
||||
|
||||
updated_at = datetime_from_string(page_last_edited_time) if page_last_edited_time else datetime.now(timezone.utc)
|
||||
semantic_identifier = caption or name or f"Notion file {block_id}"
|
||||
|
||||
return Document(
|
||||
id=block_id,
|
||||
blob=file_bytes,
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=extension,
|
||||
size_bytes=len(file_bytes),
|
||||
doc_updated_at=updated_at,
|
||||
)
|
||||
|
||||
def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
|
||||
result_blocks: list[NotionBlock] = []
|
||||
child_pages: list[str] = []
|
||||
attachments: list[Document] = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
data = self._fetch_child_blocks(base_block_id, cursor)
|
||||
|
||||
if data is None:
|
||||
return result_blocks, child_pages
|
||||
return result_blocks, child_pages, attachments
|
||||
|
||||
for result in data["results"]:
|
||||
logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
|
||||
logging.debug(f"[Notion]: Found child block for block with ID {base_block_id}: {result}")
|
||||
result_block_id = result["id"]
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
|
||||
if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
|
||||
logging.warning(f"Skipping unsupported block type '{result_type}'")
|
||||
logging.warning(f"[Notion]: Skipping unsupported block type {result_type}")
|
||||
continue
|
||||
|
||||
if result_type == "table":
|
||||
table_html = self._build_table_html(result_block_id)
|
||||
if table_html:
|
||||
result_blocks.append(
|
||||
NotionBlock(
|
||||
id=result_block_id,
|
||||
text=table_html,
|
||||
prefix="\n\n",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if result_type == "equation":
|
||||
expr = result_obj.get("expression")
|
||||
if expr:
|
||||
result_blocks.append(
|
||||
NotionBlock(
|
||||
id=result_block_id,
|
||||
text=expr,
|
||||
prefix="\n",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
cur_result_text_arr = []
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
if "text" in rich_text:
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
text = self._extract_rich_text(result_obj["rich_text"])
|
||||
if text:
|
||||
cur_result_text_arr.append(text)
|
||||
|
||||
if result_type == "bulleted_list_item":
|
||||
if cur_result_text_arr:
|
||||
cur_result_text_arr[0] = f"- {cur_result_text_arr[0]}"
|
||||
else:
|
||||
cur_result_text_arr = ["- "]
|
||||
|
||||
if result_type == "numbered_list_item":
|
||||
if cur_result_text_arr:
|
||||
cur_result_text_arr[0] = f"1. {cur_result_text_arr[0]}"
|
||||
else:
|
||||
cur_result_text_arr = ["1. "]
|
||||
|
||||
if result_type == "to_do":
|
||||
checked = result_obj.get("checked")
|
||||
checkbox_prefix = "[x]" if checked else "[ ]"
|
||||
if cur_result_text_arr:
|
||||
cur_result_text_arr = [f"{checkbox_prefix} {cur_result_text_arr[0]}"] + cur_result_text_arr[1:]
|
||||
else:
|
||||
cur_result_text_arr = [checkbox_prefix]
|
||||
|
||||
if result_type in {"file", "image", "pdf", "video", "audio"}:
|
||||
file_url, file_name, caption = self._extract_file_metadata(result_obj, result_block_id)
|
||||
if file_url:
|
||||
attachment_doc = self._build_attachment_document(
|
||||
block_id=result_block_id,
|
||||
url=file_url,
|
||||
name=file_name,
|
||||
caption=caption,
|
||||
page_last_edited_time=page_last_edited_time,
|
||||
)
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
|
||||
attachment_label = caption or file_name
|
||||
if attachment_label:
|
||||
cur_result_text_arr.append(f"{result_type.capitalize()}: {attachment_label}")
|
||||
|
||||
if result["has_children"]:
|
||||
if result_type == "child_page":
|
||||
child_pages.append(result_block_id)
|
||||
else:
|
||||
logging.debug(f"Entering sub-block: {result_block_id}")
|
||||
subblocks, subblock_child_pages = self._read_blocks(result_block_id)
|
||||
logging.debug(f"Finished sub-block: {result_block_id}")
|
||||
logging.debug(f"[Notion]: Entering sub-block: {result_block_id}")
|
||||
subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time)
|
||||
logging.debug(f"[Notion]: Finished sub-block: {result_block_id}")
|
||||
result_blocks.extend(subblocks)
|
||||
child_pages.extend(subblock_child_pages)
|
||||
attachments.extend(subblock_attachments)
|
||||
|
||||
if result_type == "child_database":
|
||||
inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
|
||||
@ -231,7 +409,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_blocks, child_pages
|
||||
return result_blocks, child_pages, attachments
|
||||
|
||||
def _read_page_title(self, page: NotionPage) -> Optional[str]:
|
||||
"""Extracts the title from a Notion page."""
|
||||
@ -245,9 +423,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
return None
|
||||
|
||||
def _read_pages(
|
||||
self, pages: list[NotionPage]
|
||||
) -> Generator[Document, None, None]:
|
||||
def _read_pages(self, pages: list[NotionPage], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[Document, None, None]:
|
||||
"""Reads pages for rich text content and generates Documents."""
|
||||
all_child_page_ids: list[str] = []
|
||||
|
||||
@ -255,11 +431,17 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
if isinstance(page, dict):
|
||||
page = NotionPage(**page)
|
||||
if page.id in self.indexed_pages:
|
||||
logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
|
||||
logging.debug(f"[Notion]: Already indexed page with ID {page.id}. Skipping.")
|
||||
continue
|
||||
|
||||
logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
|
||||
page_blocks, child_page_ids = self._read_blocks(page.id)
|
||||
if start is not None and end is not None:
|
||||
page_ts = datetime_from_string(page.last_edited_time).timestamp()
|
||||
if not (page_ts > start and page_ts <= end):
|
||||
logging.debug(f"[Notion]: Skipping page {page.id} outside polling window.")
|
||||
continue
|
||||
|
||||
logging.info(f"[Notion]: Reading page with ID {page.id}, with url {page.url}")
|
||||
page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
self.indexed_pages.add(page.id)
|
||||
|
||||
@ -268,14 +450,12 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
|
||||
logging.warning(f"[Notion]: No blocks OR title found for page with ID {page.id}. Skipping.")
|
||||
continue
|
||||
|
||||
text = page_title
|
||||
if page.properties:
|
||||
text += "\n\n" + "\n".join(
|
||||
[f"{key}: {value}" for key, value in page.properties.items()]
|
||||
)
|
||||
text += "\n\n" + "\n".join([f"{key}: {value}" for key, value in page.properties.items()])
|
||||
sections = [TextSection(link=page.url, text=text)]
|
||||
else:
|
||||
sections = [
|
||||
@ -286,45 +466,39 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
for block in page_blocks
|
||||
]
|
||||
|
||||
blob = ("\n".join([sec.text for sec in sections])).encode("utf-8")
|
||||
joined_text = "\n".join(sec.text for sec in sections)
|
||||
blob = joined_text.encode("utf-8")
|
||||
yield Document(
|
||||
id=page.id,
|
||||
blob=blob,
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=page_title,
|
||||
extension=".txt",
|
||||
size_bytes=len(blob),
|
||||
doc_updated_at=datetime_from_string(page.last_edited_time)
|
||||
id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=page_title, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
|
||||
)
|
||||
|
||||
for attachment_doc in attachment_docs:
|
||||
yield attachment_doc
|
||||
|
||||
if self.recursive_index_enabled and all_child_page_ids:
|
||||
for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
|
||||
child_page_batch = [
|
||||
self._fetch_page(page_id)
|
||||
for page_id in child_page_batch_ids
|
||||
if page_id not in self.indexed_pages
|
||||
]
|
||||
yield from self._read_pages(child_page_batch)
|
||||
child_page_batch = [self._fetch_page(page_id) for page_id in child_page_batch_ids if page_id not in self.indexed_pages]
|
||||
yield from self._read_pages(child_page_batch, start, end)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
|
||||
"""Search for pages from a Notion database."""
|
||||
logging.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
|
||||
logging.debug(f"[Notion]: Searching for pages in Notion with query_dict: {query_dict}")
|
||||
data = fetch_notion_data("https://api.notion.com/v1/search", self.headers, "POST", query_dict)
|
||||
return NotionSearchResponse(**data)
|
||||
|
||||
def _recursive_load(self) -> Generator[list[Document], None, None]:
|
||||
def _recursive_load(self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[list[Document], None, None]:
|
||||
"""Recursively load pages starting from root page ID."""
|
||||
if self.root_page_id is None or not self.recursive_index_enabled:
|
||||
raise RuntimeError("Recursive page lookup is not enabled")
|
||||
|
||||
logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
|
||||
logging.info(f"[Notion]: Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
|
||||
pages = [self._fetch_page(page_id=self.root_page_id)]
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
yield from batch_generator(self._read_pages(pages, start, end), self.batch_size)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Applies integration token to headers."""
|
||||
self.headers["Authorization"] = f'Bearer {credentials["notion_integration_token"]}'
|
||||
self.headers["Authorization"] = f"Bearer {credentials['notion_integration_token']}"
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
@ -348,12 +522,10 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
else:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
|
||||
"""Poll Notion for updated pages within a time period."""
|
||||
if self.recursive_index_enabled and self.root_page_id:
|
||||
yield from self._recursive_load()
|
||||
yield from self._recursive_load(start, end)
|
||||
return
|
||||
|
||||
query_dict = {
|
||||
@ -367,7 +539,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")
|
||||
|
||||
if pages:
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
yield from batch_generator(self._read_pages(pages, start, end), self.batch_size)
|
||||
if db_res.has_more:
|
||||
query_dict["start_cursor"] = db_res.next_cursor
|
||||
else:
|
||||
|
||||
@ -48,17 +48,35 @@ from common.data_source.exceptions import RateLimitTriedTooManyTimesError
|
||||
from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
|
||||
from common.data_source.models import BasicExpertInfo, Document
|
||||
|
||||
_TZ_SUFFIX_PATTERN = re.compile(r"([+-])([\d:]+)$")
|
||||
|
||||
|
||||
def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_string = datetime_string.strip()
|
||||
|
||||
match_jira_format = _TZ_SUFFIX_PATTERN.search(datetime_string)
|
||||
if match_jira_format:
|
||||
sign, tz_field = match_jira_format.groups()
|
||||
digits = tz_field.replace(":", "")
|
||||
|
||||
if digits.isdigit() and 1 <= len(digits) <= 4:
|
||||
if len(digits) >= 3:
|
||||
hours = digits[:-2].rjust(2, "0")
|
||||
minutes = digits[-2:]
|
||||
else:
|
||||
hours = digits.rjust(2, "0")
|
||||
minutes = "00"
|
||||
|
||||
normalized = f"{sign}{hours}:{minutes}"
|
||||
datetime_string = f"{datetime_string[: match_jira_format.start()]}{normalized}"
|
||||
|
||||
# Handle the case where the datetime string ends with 'Z' (Zulu time)
|
||||
if datetime_string.endswith('Z'):
|
||||
datetime_string = datetime_string[:-1] + '+00:00'
|
||||
if datetime_string.endswith("Z"):
|
||||
datetime_string = datetime_string[:-1] + "+00:00"
|
||||
|
||||
# Handle timezone format "+0000" -> "+00:00"
|
||||
if datetime_string.endswith('+0000'):
|
||||
datetime_string = datetime_string[:-5] + '+00:00'
|
||||
if datetime_string.endswith("+0000"):
|
||||
datetime_string = datetime_string[:-5] + "+00:00"
|
||||
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
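Some inputs the normalization above is intended to accept (illustrative; Jira-style "+0800"/"+08:00" suffixes, a trailing "Z", and "+0000" are all rewritten to ISO-8601 offsets before fromisoformat):

datetime_from_string("2024-05-01T10:00:00.000+0800")
datetime_from_string("2024-05-01T02:00:00Z")
datetime_from_string("2024-05-01T02:00:00+0000")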
@ -293,6 +311,16 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
aws_secret_access_key=credentials["secret_access_key"],
|
||||
region_name=credentials["region"],
|
||||
)
|
||||
elif bucket_type == BlobType.S3_COMPATIBLE:
|
||||
addressing_style = credentials.get("addressing_style", "virtual")
|
||||
|
||||
return boto3.client(
|
||||
"s3",
|
||||
endpoint_url=credentials["endpoint_url"],
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||
config=Config(s3={'addressing_style': addressing_style}),
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {bucket_type}")
|
||||
@ -480,7 +508,7 @@ def get_file_ext(file_name: str) -> str:
|
||||
|
||||
|
||||
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
|
||||
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
|
||||
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
|
||||
text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
|
||||
document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}
|
||||
|
||||
@ -902,6 +930,18 @@ def load_all_docs_from_checkpoint_connector(
|
||||
)
|
||||
|
||||
|
||||
_ATLASSIAN_CLOUD_DOMAINS = (".atlassian.net", ".jira.com", ".jira-dev.com")
|
||||
|
||||
|
||||
def is_atlassian_cloud_url(url: str) -> bool:
|
||||
try:
|
||||
host = urlparse(url).hostname or ""
|
||||
except ValueError:
|
||||
return False
|
||||
host = host.lower()
|
||||
return any(host.endswith(domain) for domain in _ATLASSIAN_CLOUD_DOMAINS)
|
||||
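Illustrative behaviour of the cloud-URL check above (hypothetical hosts):

assert is_atlassian_cloud_url("https://acme.atlassian.net/browse/ENG-1") is True
assert is_atlassian_cloud_url("https://jira.internal.example.com") is False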
|
||||
|
||||
def get_cloudId(base_url: str) -> str:
|
||||
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
|
||||
response = requests.get(tenant_info_url, timeout=10)
|
||||
|
||||
370
common/data_source/webdav_connector.py
Normal file
@ -0,0 +1,370 @@
|
||||
"""WebDAV connector"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from webdav4.client import Client as WebDAVClient
|
||||
|
||||
from common.data_source.utils import (
|
||||
get_file_ext,
|
||||
)
|
||||
from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE, BLOB_STORAGE_SIZE_THRESHOLD
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector
|
||||
from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput
|
||||
|
||||
|
||||
class WebDAVConnector(LoadConnector, PollConnector):
|
||||
"""WebDAV connector for syncing files from WebDAV servers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
remote_path: str = "/",
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Initialize WebDAV connector
|
||||
|
||||
Args:
|
||||
base_url: Base URL of the WebDAV server (e.g., "https://webdav.example.com")
|
||||
remote_path: Remote path to sync from (default: "/")
|
||||
batch_size: Number of documents per batch
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
if not remote_path:
|
||||
remote_path = "/"
|
||||
if not remote_path.startswith("/"):
|
||||
remote_path = f"/{remote_path}"
|
||||
if remote_path.endswith("/") and remote_path != "/":
|
||||
remote_path = remote_path.rstrip("/")
|
||||
self.remote_path = remote_path
|
||||
self.batch_size = batch_size
|
||||
self.client: Optional[WebDAVClient] = None
|
||||
self._allow_images: bool | None = None
|
||||
self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
|
||||
|
||||
def set_allow_images(self, allow_images: bool) -> None:
|
||||
"""Set whether to process images"""
|
||||
logging.info(f"Setting allow_images to {allow_images}.")
|
||||
self._allow_images = allow_images
|
||||
|
||||
    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials and initialize WebDAV client

        Args:
            credentials: Dictionary containing 'username' and 'password'

        Returns:
            None

        Raises:
            ConnectorMissingCredentialError: If required credentials are missing
        """
        logging.debug(f"Loading credentials for WebDAV server {self.base_url}")

        username = credentials.get("username")
        password = credentials.get("password")

        if not username or not password:
            raise ConnectorMissingCredentialError(
                "WebDAV requires 'username' and 'password' credentials"
            )

        try:
            # Initialize WebDAV client
            self.client = WebDAVClient(
                base_url=self.base_url,
                auth=(username, password)
            )

            # Test connection
            self.client.exists(self.remote_path)

        except Exception as e:
            logging.error(f"Failed to connect to WebDAV server: {e}")
            raise ConnectorMissingCredentialError(
                f"Failed to authenticate with WebDAV server: {e}"
            )

        return None
    def _list_files_recursive(
        self,
        path: str,
        start: datetime,
        end: datetime,
    ) -> list[tuple[str, dict]]:
        """Recursively list all files in the given path

        Args:
            path: Path to list files from
            start: Start datetime for filtering
            end: End datetime for filtering

        Returns:
            List of tuples containing (file_path, file_info)
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        files = []

        try:
            logging.debug(f"Listing directory: {path}")
            for item in self.client.ls(path, detail=True):
                item_path = item['name']

                if item_path == path or item_path == path + '/':
                    continue

                logging.debug(f"Found item: {item_path}, type: {item.get('type')}")

                if item.get('type') == 'directory':
                    try:
                        files.extend(self._list_files_recursive(item_path, start, end))
                    except Exception as e:
                        logging.error(f"Error recursing into directory {item_path}: {e}")
                        continue
                else:
                    try:
                        modified_time = item.get('modified')
                        if modified_time:
                            if isinstance(modified_time, datetime):
                                modified = modified_time
                                if modified.tzinfo is None:
                                    modified = modified.replace(tzinfo=timezone.utc)
                            elif isinstance(modified_time, str):
                                try:
                                    modified = datetime.strptime(modified_time, '%a, %d %b %Y %H:%M:%S %Z')
                                    modified = modified.replace(tzinfo=timezone.utc)
                                except (ValueError, TypeError):
                                    try:
                                        modified = datetime.fromisoformat(modified_time.replace('Z', '+00:00'))
                                    except (ValueError, TypeError):
                                        logging.warning(f"Could not parse modified time for {item_path}: {modified_time}")
                                        modified = datetime.now(timezone.utc)
                            else:
                                modified = datetime.now(timezone.utc)
                        else:
                            modified = datetime.now(timezone.utc)

                        logging.debug(f"File {item_path}: modified={modified}, start={start}, end={end}, include={start < modified <= end}")
                        if start < modified <= end:
                            files.append((item_path, item))
                        else:
                            logging.debug(f"File {item_path} filtered out by time range")
                    except Exception as e:
                        logging.error(f"Error processing file {item_path}: {e}")
                        continue

        except Exception as e:
            logging.error(f"Error listing directory {path}: {e}")

        return files
    def _yield_webdav_documents(
        self,
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        """Generate documents from WebDAV server

        Args:
            start: Start datetime for filtering
            end: End datetime for filtering

        Yields:
            Batches of documents
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        logging.info(f"Searching for files in {self.remote_path} between {start} and {end}")
        files = self._list_files_recursive(self.remote_path, start, end)
        logging.info(f"Found {len(files)} files matching time criteria")

        batch: list[Document] = []
        for file_path, file_info in files:
            file_name = os.path.basename(file_path)

            size_bytes = file_info.get('size', 0)
            if (
                self.size_threshold is not None
                and isinstance(size_bytes, int)
                and size_bytes > self.size_threshold
            ):
                logging.warning(
                    f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
                )
                continue

            try:
                logging.debug(f"Downloading file: {file_path}")
                from io import BytesIO
                buffer = BytesIO()
                self.client.download_fileobj(file_path, buffer)
                blob = buffer.getvalue()

                if blob is None or len(blob) == 0:
                    logging.warning(f"Downloaded content is empty for {file_path}")
                    continue

                modified_time = file_info.get('modified')
                if modified_time:
                    if isinstance(modified_time, datetime):
                        modified = modified_time
                        if modified.tzinfo is None:
                            modified = modified.replace(tzinfo=timezone.utc)
                    elif isinstance(modified_time, str):
                        try:
                            modified = datetime.strptime(modified_time, '%a, %d %b %Y %H:%M:%S %Z')
                            modified = modified.replace(tzinfo=timezone.utc)
                        except (ValueError, TypeError):
                            try:
                                modified = datetime.fromisoformat(modified_time.replace('Z', '+00:00'))
                            except (ValueError, TypeError):
                                logging.warning(f"Could not parse modified time for {file_path}: {modified_time}")
                                modified = datetime.now(timezone.utc)
                    else:
                        modified = datetime.now(timezone.utc)
                else:
                    modified = datetime.now(timezone.utc)

                batch.append(
                    Document(
                        id=f"webdav:{self.base_url}:{file_path}",
                        blob=blob,
                        source=DocumentSource.WEBDAV,
                        semantic_identifier=file_name,
                        extension=get_file_ext(file_name),
                        doc_updated_at=modified,
                        size_bytes=size_bytes if size_bytes else 0
                    )
                )

                if len(batch) == self.batch_size:
                    yield batch
                    batch = []

            except Exception as e:
                logging.exception(f"Error downloading file {file_path}: {e}")

        if batch:
            yield batch
    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load all documents from WebDAV server

        Yields:
            Batches of documents
        """
        logging.debug(f"Loading documents from WebDAV server {self.base_url}")
        return self._yield_webdav_documents(
            start=datetime(1970, 1, 1, tzinfo=timezone.utc),
            end=datetime.now(timezone.utc),
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Poll WebDAV server for updated documents

        Args:
            start: Start timestamp (seconds since Unix epoch)
            end: End timestamp (seconds since Unix epoch)

        Yields:
            Batches of documents
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        for batch in self._yield_webdav_documents(start_datetime, end_datetime):
            yield batch
    def validate_connector_settings(self) -> None:
        """Validate WebDAV connector settings

        Raises:
            ConnectorMissingCredentialError: If credentials are not loaded
            ConnectorValidationError: If settings are invalid
        """
        if self.client is None:
            raise ConnectorMissingCredentialError(
                "WebDAV credentials not loaded."
            )

        if not self.base_url:
            raise ConnectorValidationError(
                "No base URL was provided in connector settings."
            )

        try:
            if not self.client.exists(self.remote_path):
                raise ConnectorValidationError(
                    f"Remote path '{self.remote_path}' does not exist on WebDAV server."
                )

        except Exception as e:
            error_message = str(e)

            if "401" in error_message or "unauthorized" in error_message.lower():
                raise CredentialExpiredError(
                    "WebDAV credentials appear invalid or expired."
                )

            if "403" in error_message or "forbidden" in error_message.lower():
                raise InsufficientPermissionsError(
                    f"Insufficient permissions to access path '{self.remote_path}' on WebDAV server."
                )

            if "404" in error_message or "not found" in error_message.lower():
                raise ConnectorValidationError(
                    f"Remote path '{self.remote_path}' does not exist on WebDAV server."
                )

            raise ConnectorValidationError(
                f"Unexpected WebDAV client error: {e}"
            )
if __name__ == "__main__":
    credentials_dict = {
        "username": os.environ.get("WEBDAV_USERNAME"),
        "password": os.environ.get("WEBDAV_PASSWORD"),
    }

    connector = WebDAVConnector(
        base_url=os.environ.get("WEBDAV_URL") or "https://webdav.example.com",
        remote_path=os.environ.get("WEBDAV_PATH") or "/",
    )

    try:
        connector.load_credentials(credentials_dict)
        connector.validate_connector_settings()

        document_batch_generator = connector.load_from_state()
        for document_batch in document_batch_generator:
            print("First batch of documents:")
            for doc in document_batch:
                print(f"Document ID: {doc.id}")
                print(f"Semantic Identifier: {doc.semantic_identifier}")
                print(f"Source: {doc.source}")
                print(f"Updated At: {doc.doc_updated_at}")
                print("---")
            break

    except ConnectorMissingCredentialError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
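As a companion to the __main__ block above, a minimal sketch of incremental syncing via poll_source; it reuses the same environment variables, and the 24-hour window is purely illustrative.

import os
import time

connector = WebDAVConnector(
    base_url=os.environ.get("WEBDAV_URL") or "https://webdav.example.com",
    remote_path=os.environ.get("WEBDAV_PATH") or "/",
)
connector.load_credentials({
    "username": os.environ.get("WEBDAV_USERNAME"),
    "password": os.environ.get("WEBDAV_PASSWORD"),
})

# Poll only the last 24 hours of changes (illustrative window);
# poll_source() converts these epoch seconds to UTC datetimes internally.
now = time.time()
for batch in connector.poll_source(start=now - 86400, end=now):
    print(f"Batch of {len(batch)} recently modified WebDAV files")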
@ -80,4 +80,4 @@ def log_exception(e, *args):
            raise Exception(a.text)
        else:
            logging.error(str(a))
            raise e
    raise e
@ -21,7 +21,7 @@ import weakref
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from string import Template
from typing import Any, Literal
from typing import Any, Literal, Protocol

from typing_extensions import override

@ -30,12 +30,15 @@ from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from rag.llm.chat_model import ToolCallSession

MCPTaskType = Literal["list_tools", "tool_call"]
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]


class ToolCallSession(Protocol):
    def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...


class MCPToolCallSession(ToolCallSession):
    _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
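A minimal sketch of what the new structural ToolCallSession protocol buys: any class exposing a matching tool_call method satisfies it, without inheriting from MCPToolCallSession. EchoToolSession below is illustrative and not part of the diff.

from typing import Any

class EchoToolSession:
    """Illustrative implementation: simply echoes the requested call."""

    def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
        return f"called {name} with {arguments}"


# Structural typing: no inheritance required, only a compatible tool_call() signature.
session: ToolCallSession = EchoToolSession()
print(session.tool_call("ping", {"payload": "hello"}))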
@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession):
            await self._process_mcp_tasks(None, msg)

        else:
            await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
            await self._process_mcp_tasks(None,
                                          f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")

    async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
        while not self._close:
@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession):
            raise

    async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
        result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
        result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
                                                             timeout=timeout)

        if result.isError:
            return f"MCP server error: {result.content}"
@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) ->
    except Exception:
        logging.exception("Exception during MCP session cleanup thread management")

    logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
    logging.info(
        f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")


def shutdown_all_mcp_sessions():
@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions():
    logging.info("All MCPToolCallSession instances have been closed.")


def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
    if isinstance(mcp_tool, dict):
        return {
            "type": "function",
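A hedged sketch of calling mcp_tool_metadata_to_openai_tool with plain dict metadata. The dict keys below follow the MCP Tool schema but are assumptions, since only the start of the dict branch is visible in this hunk; the one field that can be confirmed from the diff is "type": "function".

# Illustrative MCP tool metadata passed as a plain dict (the Tool object path is analogous).
weather_tool = {
    "name": "get_weather",
    "description": "Look up current weather for a city",
    "inputSchema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}

openai_tool = mcp_tool_metadata_to_openai_tool(weather_tool)
print(openai_tool["type"])  # "function", per the branch shown above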
@ -27,6 +27,7 @@ from common.constants import SVR_QUEUE_NAME, Storage
import rag.utils
import rag.utils.es_conn
import rag.utils.infinity_conn
import rag.utils.ob_conn
import rag.utils.opensearch_conn
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob

@ -73,6 +74,8 @@ GITHUB_OAUTH = None
FEISHU_OAUTH = None
OAUTH_CONFIG = None
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")


docStoreConn = None

@ -103,6 +106,7 @@ INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OB = {}
OSS = {}
OS = {}

@ -137,7 +141,7 @@ def _get_or_create_secret_key():
        import logging

        new_key = secrets.token_hex(32)
        logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
        logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
        return new_key

class StorageFactory:

@ -227,9 +231,9 @@ def init_settings():
    FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
    OAUTH_CONFIG = get_base_config("oauth", {})

    global DOC_ENGINE, docStoreConn, ES, OS, INFINITY
    global DOC_ENGINE, DOC_ENGINE_INFINITY, docStoreConn, ES, OB, OS, INFINITY
    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
    # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
    DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
    lower_case_doc_engine = DOC_ENGINE.lower()
    if lower_case_doc_engine == "elasticsearch":
        ES = get_base_config("es", {})

@ -240,6 +244,9 @@ def init_settings():
    elif lower_case_doc_engine == "opensearch":
        OS = get_base_config("os", {})
        docStoreConn = rag.utils.opensearch_conn.OSConnection()
    elif lower_case_doc_engine == "oceanbase":
        OB = get_base_config("oceanbase", {})
        docStoreConn = rag.utils.ob_conn.OBConnection()
    else:
        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
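A hedged sketch of how the new OceanBase branch gets selected at startup. DOC_ENGINE, init_settings, and docStoreConn come from the hunks above; the settings module path is an assumption.

import os

os.environ["DOC_ENGINE"] = "oceanbase"  # matched case-insensitively via DOC_ENGINE.lower()

# Assumed module path for the settings file shown above.
from api import settings

# init_settings() reads get_base_config("oceanbase", {}) into OB and wires
# docStoreConn to rag.utils.ob_conn.OBConnection() for the oceanbase branch.
settings.init_settings()
print(type(settings.docStoreConn).__name__)  # expected: OBConnection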
Some files were not shown because too many files have changed in this diff.