mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
Compare commits
79 Commits
42edecc98f
...
v0.22.0
| Author | SHA1 | Date | |
|---|---|---|---|
| a36a0fe71c | |||
| a81f6d1b24 | |||
| 8406a5ea47 | |||
| 20b6dafbd8 | |||
| 33cc9cafa9 | |||
| 6567ecf15a | |||
| 3a7322f5b2 | |||
| 829e5f287b | |||
| 1e8efa2631 | |||
| e7f7c09b0b | |||
| 8ae562504b | |||
| bacc9d3ab9 | |||
| d226764ed0 | |||
| 39120d49cf | |||
| 27211a9b34 | |||
| e9de25c973 | |||
| 09e971dcc8 | |||
| 883df22aa2 | |||
| 2bd7abadd3 | |||
| 435479adb3 | |||
| 2c727a4a9c | |||
| a15f522dc9 | |||
| de53498b39 | |||
| 72740eb5b9 | |||
| c30ffb5716 | |||
| 6dcff7db97 | |||
| 9213568692 | |||
| d81e4095de | |||
| 8ddeaca3d6 | |||
| f441f8ffc2 | |||
| 522c7b7ac6 | |||
| 377c0fb4fa | |||
| 7dd9758056 | |||
| 26cf5131c9 | |||
| 93207f83ba | |||
| f77604db26 | |||
| dd5b8e2e1a | |||
| 83ff8e8009 | |||
| 7db6cb8ca3 | |||
| ba6470a7a5 | |||
| df16a80f25 | |||
| 29ea059f90 | |||
| a191933f81 | |||
| 6e1ebb2855 | |||
| 68b952abb1 | |||
| 0879b6af2c | |||
| 2b9145948f | |||
| 726473fd39 | |||
| d207291217 | |||
| bf382e5c4d | |||
| 4338e706c6 | |||
| 86af330f06 | |||
| d016a06fd5 | |||
| 7423a5806e | |||
| b6cd282ccd | |||
| 82ca2e0378 | |||
| 1cd54832b5 | |||
| 660386d3b5 | |||
| 4cdaa77545 | |||
| 9fcc4946e2 | |||
| 98e9d68c75 | |||
| 8f34824aa4 | |||
| 9a6808230a | |||
| c7bd0a755c | |||
| dd1c8c5779 | |||
| 526ba3388f | |||
| cb95072ecf | |||
| f6aeebc608 | |||
| 307f53dae8 | |||
| fa98cc2bb9 | |||
| c58d95ed69 | |||
| edbc396bc6 | |||
| b137de1def | |||
| 2cb1046cbf | |||
| a880beb1f6 | |||
| 34283d4db4 | |||
| 5629fbd2ca | |||
| b7aa6d6c4f | |||
| 0b7b88592f |
28
.github/workflows/release.yml
vendored
28
.github/workflows/release.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
runs-on: [ "self-hosted", "ragflow-test" ]
|
||||
steps:
|
||||
- name: Ensure workspace ownership
|
||||
run: echo "chown -R $USER $GITHUB_WORKSPACE" && sudo chown -R $USER $GITHUB_WORKSPACE
|
||||
run: echo "chown -R ${USER} ${GITHUB_WORKSPACE}" && sudo chown -R ${USER} ${GITHUB_WORKSPACE}
|
||||
|
||||
# https://github.com/actions/checkout/blob/v3/README.md
|
||||
- name: Check out code
|
||||
@ -31,37 +31,37 @@ jobs:
|
||||
|
||||
- name: Prepare release body
|
||||
run: |
|
||||
if [[ $GITHUB_EVENT_NAME == 'create' ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} == "create" ]]; then
|
||||
RELEASE_TAG=${GITHUB_REF#refs/tags/}
|
||||
if [[ $RELEASE_TAG == 'nightly' ]]; then
|
||||
if [[ ${RELEASE_TAG} == "nightly" ]]; then
|
||||
PRERELEASE=true
|
||||
else
|
||||
PRERELEASE=false
|
||||
fi
|
||||
echo "Workflow triggered by create tag: $RELEASE_TAG"
|
||||
echo "Workflow triggered by create tag: ${RELEASE_TAG}"
|
||||
else
|
||||
RELEASE_TAG=nightly
|
||||
PRERELEASE=true
|
||||
echo "Workflow triggered by schedule"
|
||||
fi
|
||||
echo "RELEASE_TAG=$RELEASE_TAG" >> $GITHUB_ENV
|
||||
echo "PRERELEASE=$PRERELEASE" >> $GITHUB_ENV
|
||||
echo "RELEASE_TAG=${RELEASE_TAG}" >> ${GITHUB_ENV}
|
||||
echo "PRERELEASE=${PRERELEASE}" >> ${GITHUB_ENV}
|
||||
RELEASE_DATETIME=$(date --rfc-3339=seconds)
|
||||
echo Release $RELEASE_TAG created from $GITHUB_SHA at $RELEASE_DATETIME > release_body.md
|
||||
echo Release ${RELEASE_TAG} created from ${GITHUB_SHA} at ${RELEASE_DATETIME} > release_body.md
|
||||
|
||||
- name: Move the existing mutable tag
|
||||
# https://github.com/softprops/action-gh-release/issues/171
|
||||
run: |
|
||||
git fetch --tags
|
||||
if [[ $GITHUB_EVENT_NAME == 'schedule' ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then
|
||||
# Determine if a given tag exists and matches a specific Git commit.
|
||||
# actions/checkout@v4 fetch-tags doesn't work when triggered by schedule
|
||||
if [ "$(git rev-parse -q --verify "refs/tags/$RELEASE_TAG")" = "$GITHUB_SHA" ]; then
|
||||
echo "mutable tag $RELEASE_TAG exists and matches $GITHUB_SHA"
|
||||
if [ "$(git rev-parse -q --verify "refs/tags/${RELEASE_TAG}")" = "${GITHUB_SHA}" ]; then
|
||||
echo "mutable tag ${RELEASE_TAG} exists and matches ${GITHUB_SHA}"
|
||||
else
|
||||
git tag -f $RELEASE_TAG $GITHUB_SHA
|
||||
git push -f origin $RELEASE_TAG:refs/tags/$RELEASE_TAG
|
||||
echo "created/moved mutable tag $RELEASE_TAG to $GITHUB_SHA"
|
||||
git tag -f ${RELEASE_TAG} ${GITHUB_SHA}
|
||||
git push -f origin ${RELEASE_TAG}:refs/tags/${RELEASE_TAG}
|
||||
echo "created/moved mutable tag ${RELEASE_TAG} to ${GITHUB_SHA}"
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -87,7 +87,7 @@ jobs:
|
||||
|
||||
- name: Build and push image
|
||||
run: |
|
||||
echo ${{ secrets.DOCKERHUB_TOKEN }} | sudo docker login --username infiniflow --password-stdin
|
||||
sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
sudo docker build --build-arg NEED_MIRROR=1 -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile .
|
||||
sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest
|
||||
sudo docker push infiniflow/ragflow:${RELEASE_TAG}
|
||||
|
||||
43
.github/workflows/tests.yml
vendored
43
.github/workflows/tests.yml
vendored
@ -9,8 +9,11 @@ on:
|
||||
- 'docs/**'
|
||||
- '*.md'
|
||||
- '*.mdx'
|
||||
pull_request:
|
||||
types: [ labeled, synchronize, reopened ]
|
||||
# The only difference between pull_request and pull_request_target is the context in which the workflow runs:
|
||||
# — pull_request_target workflows use the workflow files from the default branch, and secrets are available.
|
||||
# — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable.
|
||||
pull_request_target:
|
||||
types: [ synchronize, ready_for_review ]
|
||||
paths-ignore:
|
||||
- 'docs/**'
|
||||
- '*.md'
|
||||
@ -28,7 +31,7 @@ jobs:
|
||||
name: ragflow_tests
|
||||
# https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
|
||||
# https://github.com/orgs/community/discussions/26261
|
||||
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'ci') }}
|
||||
if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }}
|
||||
runs-on: [ "self-hosted", "ragflow-test" ]
|
||||
steps:
|
||||
# https://github.com/hmarr/debug-action
|
||||
@ -37,19 +40,20 @@ jobs:
|
||||
- name: Ensure workspace ownership
|
||||
run: |
|
||||
echo "Workflow triggered by ${{ github.event_name }}"
|
||||
echo "chown -R $USER $GITHUB_WORKSPACE" && sudo chown -R $USER $GITHUB_WORKSPACE
|
||||
echo "chown -R ${USER} ${GITHUB_WORKSPACE}" && sudo chown -R ${USER} ${GITHUB_WORKSPACE}
|
||||
|
||||
# https://github.com/actions/checkout/issues/1781
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ (github.event_name == 'pull_request' || github.event_name == 'pull_request_target') && format('refs/pull/{0}/merge', github.event.pull_request.number) || github.sha }}
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Check workflow duplication
|
||||
if: ${{ !cancelled() && !failure() && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'ci')) }}
|
||||
if: ${{ !cancelled() && !failure() }}
|
||||
run: |
|
||||
if [[ "$GITHUB_EVENT_NAME" != "pull_request" && "$GITHUB_EVENT_NAME" != "schedule" ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
HEAD=$(git rev-parse HEAD)
|
||||
# Find a PR that introduced a given commit
|
||||
gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}"
|
||||
@ -67,14 +71,14 @@ jobs:
|
||||
gh run cancel ${GITHUB_RUN_ID}
|
||||
while true; do
|
||||
status=$(gh run view ${GITHUB_RUN_ID} --json status -q .status)
|
||||
[ "$status" = "completed" ] && break
|
||||
[ "${status}" = "completed" ] && break
|
||||
sleep 5
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
else
|
||||
elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then
|
||||
PR_NUMBER=${{ github.event.pull_request.number }}
|
||||
PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
|
||||
# Calculate the hash of the current workspace content
|
||||
@ -93,18 +97,18 @@ jobs:
|
||||
|
||||
- name: Build ragflow:nightly
|
||||
run: |
|
||||
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-$HOME}
|
||||
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:${GITHUB_RUN_ID}
|
||||
echo "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> $GITHUB_ENV
|
||||
echo "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> ${GITHUB_ENV}
|
||||
sudo docker pull ubuntu:22.04
|
||||
sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 -f Dockerfile -t ${RAGFLOW_IMAGE} .
|
||||
if [[ "$GITHUB_EVENT_NAME" == "schedule" ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then
|
||||
export HTTP_API_TEST_LEVEL=p3
|
||||
else
|
||||
export HTTP_API_TEST_LEVEL=p2
|
||||
fi
|
||||
echo "HTTP_API_TEST_LEVEL=${HTTP_API_TEST_LEVEL}" >> $GITHUB_ENV
|
||||
echo "RAGFLOW_CONTAINER=${GITHUB_RUN_ID}-ragflow-cpu-1" >> $GITHUB_ENV
|
||||
echo "HTTP_API_TEST_LEVEL=${HTTP_API_TEST_LEVEL}" >> ${GITHUB_ENV}
|
||||
echo "RAGFLOW_CONTAINER=${GITHUB_RUN_ID}-ragflow-cpu-1" >> ${GITHUB_ENV}
|
||||
|
||||
- name: Start ragflow:nightly
|
||||
run: |
|
||||
@ -154,7 +158,7 @@ jobs:
|
||||
echo -e "COMPOSE_PROFILES=\${COMPOSE_PROFILES},tei-cpu" >> docker/.env
|
||||
echo -e "TEI_MODEL=BAAI/bge-small-en-v1.5" >> docker/.env
|
||||
echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
|
||||
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> $GITHUB_ENV
|
||||
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
|
||||
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
|
||||
uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python
|
||||
@ -189,7 +193,8 @@ jobs:
|
||||
- name: Stop ragflow:nightly
|
||||
if: always() # always run this step even if previous steps failed
|
||||
run: |
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true
|
||||
sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f
|
||||
|
||||
- name: Start ragflow:nightly
|
||||
run: |
|
||||
@ -226,5 +231,9 @@ jobs:
|
||||
- name: Stop ragflow:nightly
|
||||
if: always() # always run this step even if previous steps failed
|
||||
run: |
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v
|
||||
sudo docker rmi -f ${RAGFLOW_IMAGE:-NO_IMAGE} || true
|
||||
# Sometimes `docker compose down` fail due to hang container, heavy load etc. Need to remove such containers to release resources(for example, listen ports).
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} down -v || true
|
||||
sudo docker ps -a --filter "label=com.docker.compose.project=${GITHUB_RUN_ID}" -q | xargs -r sudo docker rm -f
|
||||
if [[ -n ${RAGFLOW_IMAGE} ]]; then
|
||||
sudo docker rmi -f ${RAGFLOW_IMAGE}
|
||||
fi
|
||||
|
||||
31
README.md
31
README.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.1">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -61,8 +61,7 @@
|
||||
- 🔎 [System Architecture](#-system-architecture)
|
||||
- 🎬 [Get Started](#-get-started)
|
||||
- 🔧 [Configurations](#-configurations)
|
||||
- 🔧 [Build a docker image without embedding models](#-build-a-docker-image-without-embedding-models)
|
||||
- 🔧 [Build a docker image including embedding models](#-build-a-docker-image-including-embedding-models)
|
||||
- 🔧 [Build a Docker image](#-build-a-docker-image)
|
||||
- 🔨 [Launch service from source for development](#-launch-service-from-source-for-development)
|
||||
- 📚 [Documentation](#-documentation)
|
||||
- 📜 [Roadmap](#-roadmap)
|
||||
@ -86,6 +85,7 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Latest Updates
|
||||
|
||||
- 2025-11-12 Supports data synchronization from Confluence, AWS S3, Discord, Google Drive.
|
||||
- 2025-10-23 Supports MinerU & Docling as document parsing methods.
|
||||
- 2025-10-15 Supports orchestrable ingestion pipeline.
|
||||
- 2025-08-08 Supports OpenAI's latest GPT-5 series models.
|
||||
@ -93,7 +93,6 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
- 2025-05-23 Adds a Python/JavaScript code executor component to Agent.
|
||||
- 2025-05-05 Supports cross-language query.
|
||||
- 2025-03-19 Supports using a multi-modal model to make sense of images within PDF or DOCX files.
|
||||
- 2025-02-28 Combined with Internet search (Tavily), supports reasoning like Deep Research for any LLMs.
|
||||
- 2024-12-18 Upgrades Document Layout Analysis model in DeepDoc.
|
||||
- 2024-08-22 Support text to SQL statements through RAG.
|
||||
|
||||
@ -189,25 +188,29 @@ releases! 🌟
|
||||
> All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
|
||||
> If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.
|
||||
|
||||
> The command below downloads the `v0.21.1-slim` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.21.1-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
|
||||
> The command below downloads the `v0.22.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
# Use CPU for embedding and DeepDoc tasks:
|
||||
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
# To use GPU to accelerate embedding and DeepDoc tasks:
|
||||
# To use GPU to accelerate DeepDoc tasks:
|
||||
# sed -i '1i DEVICE=gpu' .env
|
||||
# docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | -------------------------- |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
> Note: Prior to `v0.22.0`, we provided both images with embedding models and slim images without embedding models. Details as follows:
|
||||
|
||||
> Note: Starting with `v0.22.0`, we ship only the slim edition and no longer append the **-slim** suffix to the image tag.
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> Starting with `v0.22.0`, we ship only the slim edition and no longer append the **-slim** suffix to the image tag.
|
||||
|
||||
4. Check the server status after having the server up and running:
|
||||
|
||||
@ -288,7 +291,7 @@ RAGFlow uses Elasticsearch by default for storing full text and vectors. To swit
|
||||
> [!WARNING]
|
||||
> Switching to Infinity on a Linux/arm64 machine is not yet officially supported.
|
||||
|
||||
## 🔧 Build a Docker image without embedding models
|
||||
## 🔧 Build a Docker image
|
||||
|
||||
This image is approximately 2 GB in size and relies on external LLM and embedding services.
|
||||
|
||||
|
||||
31
README_id.md
31
README_id.md
@ -22,7 +22,7 @@
|
||||
<img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.1">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
|
||||
@ -61,8 +61,7 @@
|
||||
- 🔎 [Arsitektur Sistem](#-arsitektur-sistem)
|
||||
- 🎬 [Mulai](#-mulai)
|
||||
- 🔧 [Konfigurasi](#-konfigurasi)
|
||||
- 🔧 [Membangun Image Docker tanpa Model Embedding](#-membangun-image-docker-tanpa-model-embedding)
|
||||
- 🔧 [Membangun Image Docker dengan Model Embedding](#-membangun-image-docker-dengan-model-embedding)
|
||||
- 🔧 [Membangun Image Docker](#-membangun-docker-image)
|
||||
- 🔨 [Meluncurkan aplikasi dari Sumber untuk Pengembangan](#-meluncurkan-aplikasi-dari-sumber-untuk-pengembangan)
|
||||
- 📚 [Dokumentasi](#-dokumentasi)
|
||||
- 📜 [Peta Jalan](#-peta-jalan)
|
||||
@ -86,6 +85,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Pembaruan Terbaru
|
||||
|
||||
- 2025-11-12 Mendukung sinkronisasi data dari Confluence, AWS S3, Discord, Google Drive.
|
||||
- 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen.
|
||||
- 2025-10-15 Dukungan untuk jalur data yang terorkestrasi.
|
||||
- 2025-08-08 Mendukung model seri GPT-5 terbaru dari OpenAI.
|
||||
@ -93,7 +93,6 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
- 2025-05-23 Menambahkan komponen pelaksana kode Python/JS ke Agen.
|
||||
- 2025-05-05 Mendukung kueri lintas bahasa.
|
||||
- 2025-03-19 Mendukung penggunaan model multi-modal untuk memahami gambar di dalam file PDF atau DOCX.
|
||||
- 2025-02-28 dikombinasikan dengan pencarian Internet (TAVILY), mendukung penelitian mendalam untuk LLM apa pun.
|
||||
- 2024-12-18 Meningkatkan model Analisis Tata Letak Dokumen di DeepDoc.
|
||||
- 2024-08-22 Dukungan untuk teks ke pernyataan SQL melalui RAG.
|
||||
|
||||
@ -187,25 +186,29 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
|
||||
> Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> Perintah di bawah ini mengunduh edisi v0.21.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.21.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
|
||||
> Perintah di bawah ini mengunduh edisi v0.22.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
# Use CPU for embedding and DeepDoc tasks:
|
||||
|
||||
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
# To use GPU to accelerate embedding and DeepDoc tasks:
|
||||
# To use GPU to accelerate DeepDoc tasks:
|
||||
# sed -i '1i DEVICE=gpu' .env
|
||||
# docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | -------------------------- |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
> Catatan: Sebelum `v0.22.0`, kami menyediakan image dengan model embedding dan image slim tanpa model embedding. Detailnya sebagai berikut:
|
||||
|
||||
> Catatan: Mulai dari `v0.22.0`, kami hanya menyediakan edisi slim dan tidak lagi menambahkan akhiran **-slim** pada tag image.
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> Mulai dari `v0.22.0`, kami hanya menyediakan edisi slim dan tidak lagi menambahkan akhiran **-slim** pada tag image.
|
||||
|
||||
1. Periksa status server setelah server aktif dan berjalan:
|
||||
|
||||
@ -260,7 +263,7 @@ Pembaruan konfigurasi ini memerlukan reboot semua kontainer agar efektif:
|
||||
> $ docker compose -f docker-compose.yml up -d
|
||||
> ```
|
||||
|
||||
## 🔧 Membangun Docker Image tanpa Model Embedding
|
||||
## 🔧 Membangun Docker Image
|
||||
|
||||
Image ini berukuran sekitar 2 GB dan bergantung pada aplikasi LLM eksternal dan embedding.
|
||||
|
||||
|
||||
30
README_ja.md
30
README_ja.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.1">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -66,6 +66,7 @@
|
||||
|
||||
## 🔥 最新情報
|
||||
|
||||
- 2025-11-12 Confluence、AWS S3、Discord、Google Drive からのデータ同期をサポートします。
|
||||
- 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。
|
||||
- 2025-10-15 オーケストレーションされたデータパイプラインのサポート。
|
||||
- 2025-08-08 OpenAI の最新 GPT-5 シリーズモデルをサポートします。
|
||||
@ -73,7 +74,6 @@
|
||||
- 2025-05-23 エージェントに Python/JS コードエグゼキュータコンポーネントを追加しました。
|
||||
- 2025-05-05 言語間クエリをサポートしました。
|
||||
- 2025-03-19 PDFまたはDOCXファイル内の画像を理解するために、多モーダルモデルを使用することをサポートします。
|
||||
- 2025-02-28 インターネット検索 (TAVILY) と組み合わせて、あらゆる LLM の詳細な調査をサポートします。
|
||||
- 2024-12-18 DeepDoc のドキュメント レイアウト分析モデルをアップグレードします。
|
||||
- 2024-08-22 RAG を介して SQL ステートメントへのテキストをサポートします。
|
||||
|
||||
@ -166,28 +166,32 @@
|
||||
> 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
|
||||
> ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。
|
||||
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.21.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.21.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
|
||||
> 以下のコマンドは、RAGFlow Docker イメージの v0.22.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
# Use CPU for embedding and DeepDoc tasks:
|
||||
|
||||
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
# To use GPU to accelerate embedding and DeepDoc tasks:
|
||||
# To use GPU to accelerate DeepDoc tasks:
|
||||
# sed -i '1i DEVICE=gpu' .env
|
||||
# docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | -------------------------- |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
> 注意:`v0.22.0` より前のバージョンでは、embedding モデルを含むイメージと、embedding モデルを含まない slim イメージの両方を提供していました。詳細は以下の通りです:
|
||||
|
||||
> 注意:`v0.22.0` 以降、当プロジェクトでは slim エディションのみを提供し、イメージタグに **-slim** サフィックスを付けなくなりました。
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
1. サーバーを立ち上げた後、サーバーの状態を確認する:
|
||||
> `v0.22.0` 以降、当プロジェクトでは slim エディションのみを提供し、イメージタグに **-slim** サフィックスを付けなくなりました。
|
||||
|
||||
1. サーバーを立ち上げた後、サーバーの状態を確認する:
|
||||
|
||||
```bash
|
||||
$ docker logs -f docker-ragflow-cpu-1
|
||||
```
|
||||
@ -259,7 +263,7 @@ RAGFlow はデフォルトで Elasticsearch を使用して全文とベクトル
|
||||
> Linux/arm64 マシンでの Infinity への切り替えは正式にサポートされていません。
|
||||
>
|
||||
|
||||
## 🔧 ソースコードで Docker イメージを作成(埋め込みモデルなし)
|
||||
## 🔧 ソースコードで Docker イメージを作成
|
||||
|
||||
この Docker イメージのサイズは約 1GB で、外部の大モデルと埋め込みサービスに依存しています。
|
||||
|
||||
|
||||
30
README_ko.md
30
README_ko.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.1">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -67,6 +67,7 @@
|
||||
|
||||
## 🔥 업데이트
|
||||
|
||||
- 2025-11-12 Confluence, AWS S3, Discord, Google Drive에서 데이터 동기화를 지원합니다.
|
||||
- 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다.
|
||||
- 2025-10-15 조정된 데이터 파이프라인 지원.
|
||||
- 2025-08-08 OpenAI의 최신 GPT-5 시리즈 모델을 지원합니다.
|
||||
@ -74,7 +75,6 @@
|
||||
- 2025-05-23 Agent에 Python/JS 코드 실행기 구성 요소를 추가합니다.
|
||||
- 2025-05-05 언어 간 쿼리를 지원합니다.
|
||||
- 2025-03-19 PDF 또는 DOCX 파일 내의 이미지를 이해하기 위해 다중 모드 모델을 사용하는 것을 지원합니다.
|
||||
- 2025-02-28 인터넷 검색(TAVILY)과 결합되어 모든 LLM에 대한 심층 연구를 지원합니다.
|
||||
- 2024-12-18 DeepDoc의 문서 레이아웃 분석 모델 업그레이드.
|
||||
- 2024-08-22 RAG를 통해 SQL 문에 텍스트를 지원합니다.
|
||||
|
||||
@ -168,25 +168,29 @@
|
||||
> 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
|
||||
> ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).
|
||||
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.21.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.21.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
|
||||
> 아래 명령어는 RAGFlow Docker 이미지의 v0.22.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
# Use CPU for embedding and DeepDoc tasks:
|
||||
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
# To use GPU to accelerate embedding and DeepDoc tasks:
|
||||
# To use GPU to accelerate DeepDoc tasks:
|
||||
# sed -i '1i DEVICE=gpu' .env
|
||||
# docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
```
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
> 참고: `v0.22.0` 이전 버전에서는 embedding 모델이 포함된 이미지와 embedding 모델이 포함되지 않은 slim 이미지를 모두 제공했습니다. 자세한 내용은 다음과 같습니다:
|
||||
|
||||
> 참고: `v0.22.0`부터는 slim 에디션만 배포하며 이미지 태그에 **-slim** 접미사를 더 이상 붙이지 않습니다.
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> `v0.22.0`부터는 slim 에디션만 배포하며 이미지 태그에 **-slim** 접미사를 더 이상 붙이지 않습니다.
|
||||
|
||||
1. 서버가 시작된 후 서버 상태를 확인하세요:
|
||||
|
||||
@ -253,7 +257,7 @@ RAGFlow 는 기본적으로 Elasticsearch 를 사용하여 전체 텍스트 및
|
||||
> [!WARNING]
|
||||
> Linux/arm64 시스템에서 Infinity로 전환하는 것은 공식적으로 지원되지 않습니다.
|
||||
|
||||
## 🔧 소스 코드로 Docker 이미지를 컴파일합니다(임베딩 모델 포함하지 않음)
|
||||
## 🔧 소스 코드로 Docker 이미지를 컴파일합니다
|
||||
|
||||
이 Docker 이미지의 크기는 약 1GB이며, 외부 대형 모델과 임베딩 서비스에 의존합니다.
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.1">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
|
||||
@ -86,6 +86,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
|
||||
## 🔥 Últimas Atualizações
|
||||
|
||||
- 12-11-2025 Suporta a sincronização de dados do Confluence, AWS S3, Discord e Google Drive.
|
||||
- 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos.
|
||||
- 15-10-2025 Suporte para pipelines de dados orquestrados.
|
||||
- 08-08-2025 Suporta a mais recente série GPT-5 da OpenAI.
|
||||
@ -93,7 +94,6 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
- 23-05-2025 Adicione o componente executor de código Python/JS ao Agente.
|
||||
- 05-05-2025 Suporte a consultas entre idiomas.
|
||||
- 19-03-2025 Suporta o uso de um modelo multi-modal para entender imagens dentro de arquivos PDF ou DOCX.
|
||||
- 28-02-2025 combinado com a pesquisa na Internet (T AVI LY), suporta pesquisas profundas para qualquer LLM.
|
||||
- 18-12-2024 Atualiza o modelo de Análise de Layout de Documentos no DeepDoc.
|
||||
- 22-08-2024 Suporta conversão de texto para comandos SQL via RAG.
|
||||
|
||||
@ -186,25 +186,29 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
|
||||
> Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
|
||||
> Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.
|
||||
|
||||
> O comando abaixo baixa a edição`v0.21.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.21.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
|
||||
> O comando abaixo baixa a edição`v0.22.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor.
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
# Use CPU for embedding and DeepDoc tasks:
|
||||
|
||||
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
# To use GPU to accelerate embedding and DeepDoc tasks:
|
||||
# To use GPU to accelerate DeepDoc tasks:
|
||||
# sed -i '1i DEVICE=gpu' .env
|
||||
# docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
| Tag da imagem RAGFlow | Tamanho da imagem (GB) | Possui modelos de incorporação? | Estável? |
|
||||
| --------------------- | ---------------------- | --------------------------------- | ------------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Lançamento estável |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Lançamento estável |
|
||||
| nightly | ≈2 | ❌ | Construção noturna instável |
|
||||
> Nota: Antes da `v0.22.0`, fornecíamos imagens com modelos de embedding e imagens slim sem modelos de embedding. Detalhes a seguir:
|
||||
|
||||
> Observação: A partir da`v0.22.0`, distribuímos apenas a edição slim e não adicionamos mais o sufixo **-slim** às tags das imagens.
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> A partir da `v0.22.0`, distribuímos apenas a edição slim e não adicionamos mais o sufixo **-slim** às tags das imagens.
|
||||
|
||||
4. Verifique o status do servidor após tê-lo iniciado:
|
||||
|
||||
@ -274,9 +278,9 @@ O RAGFlow usa o Elasticsearch por padrão para armazenar texto completo e vetore
|
||||
```
|
||||
|
||||
> [!ATENÇÃO]
|
||||
> A mudança para o Infinity em uma máquina Linux/arm64 ainda não é oficialmente suportada.
|
||||
> A mudança para o Infinity em uma máquina Linux/arm64 ainda não é oficialmente suportada.
|
||||
|
||||
## 🔧 Criar uma imagem Docker sem modelos de incorporação
|
||||
## 🔧 Criar uma imagem Docker
|
||||
|
||||
Esta imagem tem cerca de 2 GB de tamanho e depende de serviços externos de LLM e incorporação.
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.1">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,6 +85,7 @@
|
||||
|
||||
## 🔥 近期更新
|
||||
|
||||
- 2025-11-12 支援從 Confluence、AWS S3、Discord、Google Drive 進行資料同步。
|
||||
- 2025-10-23 支援 MinerU 和 Docling 作為文件解析方法。
|
||||
- 2025-10-15 支援可編排的資料管道。
|
||||
- 2025-08-08 支援 OpenAI 最新的 GPT-5 系列模型。
|
||||
@ -92,7 +93,6 @@
|
||||
- 2025-05-23 為 Agent 新增 Python/JS 程式碼執行器元件。
|
||||
- 2025-05-05 支援跨語言查詢。
|
||||
- 2025-03-19 PDF和DOCX中的圖支持用多模態大模型去解析得到描述.
|
||||
- 2025-02-28 結合網路搜尋(Tavily),對於任意大模型實現類似 Deep Research 的推理功能.
|
||||
- 2024-12-18 升級了 DeepDoc 的文檔佈局分析模型。
|
||||
- 2024-08-22 支援用 RAG 技術實現從自然語言到 SQL 語句的轉換。
|
||||
|
||||
@ -185,25 +185,29 @@
|
||||
> 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。
|
||||
> 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。
|
||||
|
||||
> 執行以下指令會自動下載 RAGFlow slim Docker 映像 `v0.21.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.21.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
|
||||
> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.0` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
# Use CPU for embedding and DeepDoc tasks:
|
||||
|
||||
# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases),例:git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
# To use GPU to accelerate embedding and DeepDoc tasks:
|
||||
# To use GPU to accelerate DeepDoc tasks:
|
||||
# sed -i '1i DEVICE=gpu' .env
|
||||
# docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | -------------------------- |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
> 注意:在 `v0.22.0` 之前的版本,我們會同時提供包含 embedding 模型的映像和不含 embedding 模型的 slim 映像。具體如下:
|
||||
|
||||
> 注意:自 `v0.22.0` 起,我們僅發佈 slim 版本,並且不再在映像標籤後附加 **-slim** 後綴。
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
|
||||
> 從 `v0.22.0` 開始,我們只發佈 slim 版本,並且不再在映像標籤後附加 **-slim** 後綴。
|
||||
|
||||
> [!TIP]
|
||||
> 如果你遇到 Docker 映像檔拉不下來的問題,可以在 **docker/.env** 檔案內根據變數 `RAGFLOW_IMAGE` 的註解提示選擇華為雲或阿里雲的對應映像。
|
||||
@ -285,7 +289,7 @@ RAGFlow 預設使用 Elasticsearch 儲存文字和向量資料. 如果要切換
|
||||
> [!WARNING]
|
||||
> Infinity 目前官方並未正式支援在 Linux/arm64 架構下的機器上運行.
|
||||
|
||||
## 🔧 原始碼編譯 Docker 映像(不含 embedding 模型)
|
||||
## 🔧 原始碼編譯 Docker 映像
|
||||
|
||||
本 Docker 映像大小約 2 GB 左右並且依賴外部的大模型和 embedding 服務。
|
||||
|
||||
|
||||
20
README_zh.md
20
README_zh.md
@ -22,7 +22,7 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.1">
|
||||
<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.22.0">
|
||||
</a>
|
||||
<a href="https://github.com/infiniflow/ragflow/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
|
||||
@ -85,6 +85,7 @@
|
||||
|
||||
## 🔥 近期更新
|
||||
|
||||
- 2025-11-12 支持从 Confluence、AWS S3、Discord、Google Drive 进行数据同步。
|
||||
- 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。
|
||||
- 2025-10-15 支持可编排的数据管道。
|
||||
- 2025-08-08 支持 OpenAI 最新的 GPT-5 系列模型。
|
||||
@ -92,7 +93,6 @@
|
||||
- 2025-05-23 Agent 新增 Python/JS 代码执行器组件。
|
||||
- 2025-05-05 支持跨语言查询。
|
||||
- 2025-03-19 PDF 和 DOCX 中的图支持用多模态大模型去解析得到描述.
|
||||
- 2025-02-28 结合互联网搜索(Tavily),对于任意大模型实现类似 Deep Research 的推理功能.
|
||||
- 2024-12-18 升级了 DeepDoc 的文档布局分析模型。
|
||||
- 2024-08-22 支持用 RAG 技术实现从自然语言到 SQL 语句的转换。
|
||||
|
||||
@ -186,25 +186,29 @@
|
||||
> 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。
|
||||
> 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。
|
||||
|
||||
> 运行以下命令会自动下载 RAGFlow slim Docker 镜像 `v0.21.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.21.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
|
||||
> 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.0` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。
|
||||
|
||||
```bash
|
||||
$ cd ragflow/docker
|
||||
# Use CPU for embedding and DeepDoc tasks:
|
||||
|
||||
# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
# To use GPU to accelerate embedding and DeepDoc tasks:
|
||||
# To use GPU to accelerate DeepDoc tasks:
|
||||
# sed -i '1i DEVICE=gpu' .env
|
||||
# docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
> 注意:在 `v0.22.0` 之前的版本,我们会同时提供包含 embedding 模型的镜像和不含 embedding 模型的 slim 镜像。具体如下:
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? |
|
||||
| ----------------- | --------------- | --------------------- | ------------------------ |
|
||||
| v0.21.1 | ≈9 | ✔️ | Stable release |
|
||||
| v0.21.1-slim | ≈2 | ❌ | Stable release |
|
||||
| nightly | ≈2 | ❌ | _Unstable_ nightly build |
|
||||
|
||||
> 注意:从 `v0.22.0` 开始,我们只发布 slim 版本,并且不再在镜像标签后附加 **-slim** 后缀。
|
||||
> 从 `v0.22.0` 开始,我们只发布 slim 版本,并且不再在镜像标签后附加 **-slim** 后缀。
|
||||
|
||||
> [!TIP]
|
||||
> 如果你遇到 Docker 镜像拉不下来的问题,可以在 **docker/.env** 文件内根据变量 `RAGFLOW_IMAGE` 的注释提示选择华为云或者阿里云的相应镜像。
|
||||
@ -284,7 +288,7 @@ RAGFlow 默认使用 Elasticsearch 存储文本和向量数据. 如果要切换
|
||||
> [!WARNING]
|
||||
> Infinity 目前官方并未正式支持在 Linux/arm64 架构下的机器上运行.
|
||||
|
||||
## 🔧 源码编译 Docker 镜像(不含 embedding 模型)
|
||||
## 🔧 源码编译 Docker 镜像
|
||||
|
||||
本 Docker 镜像大小约 2 GB 左右并且依赖外部的大模型和 embedding 服务。
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
|
||||
1. Ensure the Admin Service is running.
|
||||
2. Install ragflow-cli.
|
||||
```bash
|
||||
pip install ragflow-cli==0.21.1
|
||||
pip install ragflow-cli==0.22.0
|
||||
```
|
||||
3. Launch the CLI client:
|
||||
```bash
|
||||
|
||||
@ -378,7 +378,7 @@ class AdminCLI(Cmd):
|
||||
self.session.headers.update({
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': response.headers['Authorization'],
|
||||
'User-Agent': 'RAGFlow-CLI/0.21.1'
|
||||
'User-Agent': 'RAGFlow-CLI/0.22.0'
|
||||
})
|
||||
print("Authentication successful.")
|
||||
return True
|
||||
@ -392,6 +392,21 @@ class AdminCLI(Cmd):
|
||||
print(str(e))
|
||||
print(f"Can't access {self.host}, port: {self.port}")
|
||||
|
||||
def _format_service_detail_table(self, data):
|
||||
if not any([isinstance(v, list) for v in data.values()]):
|
||||
# normal table
|
||||
return data
|
||||
# handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]
|
||||
task_executor_list = []
|
||||
for k, v in data.items():
|
||||
# display latest status
|
||||
heartbeats = sorted(v, key=lambda x: x["now"], reverse=True)
|
||||
task_executor_list.append({
|
||||
"task_executor_name": k,
|
||||
**heartbeats[0],
|
||||
})
|
||||
return task_executor_list
|
||||
|
||||
def _print_table_simple(self, data):
|
||||
if not data:
|
||||
print("No data to print")
|
||||
@ -595,7 +610,8 @@ class AdminCLI(Cmd):
|
||||
if isinstance(res_data['message'], str):
|
||||
print(res_data['message'])
|
||||
else:
|
||||
self._print_table_simple(res_data['message'])
|
||||
data = self._format_service_detail_table(res_data['message'])
|
||||
self._print_table_simple(data)
|
||||
else:
|
||||
print(f"Service {res_data['service_name']} is down, {res_data['message']}")
|
||||
else:
|
||||
@ -632,7 +648,9 @@ class AdminCLI(Cmd):
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
table_data = res_json['data']
|
||||
table_data.pop('avatar')
|
||||
self._print_table_simple(table_data)
|
||||
else:
|
||||
print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
@ -705,7 +723,10 @@ class AdminCLI(Cmd):
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
table_data = res_json['data']
|
||||
for t in table_data:
|
||||
t.pop('avatar')
|
||||
self._print_table_simple(table_data)
|
||||
else:
|
||||
print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
@ -717,7 +738,10 @@ class AdminCLI(Cmd):
|
||||
response = self.session.get(url)
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
self._print_table_simple(res_json['data'])
|
||||
table_data = res_json['data']
|
||||
for t in table_data:
|
||||
t.pop('avatar')
|
||||
self._print_table_simple(table_data)
|
||||
else:
|
||||
print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ragflow-cli"
|
||||
version = "0.21.1"
|
||||
version = "0.22.0"
|
||||
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
|
||||
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
|
||||
license = { text = "Apache License, Version 2.0" }
|
||||
|
||||
@ -183,11 +183,13 @@ class RAGFlowServerConfig(BaseConfig):
|
||||
|
||||
|
||||
class TaskExecutorConfig(BaseConfig):
|
||||
message_queue_type: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if 'extra' not in result:
|
||||
result['extra'] = dict()
|
||||
result['extra']['message_queue_type'] = self.message_queue_type
|
||||
return result
|
||||
|
||||
|
||||
@ -299,6 +301,15 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
|
||||
id_count += 1
|
||||
case "admin":
|
||||
pass
|
||||
case "task_executor":
|
||||
name: str = 'task_executor'
|
||||
host: str = v.get('host', '')
|
||||
port: int = v.get('port', 0)
|
||||
message_queue_type: str = v.get('message_queue_type')
|
||||
config = TaskExecutorConfig(id=id_count, name=name, host=host, port=port, message_queue_type=message_queue_type,
|
||||
service_type="task_executor", detail_func_name="check_task_executor_alive")
|
||||
configurations.append(config)
|
||||
id_count += 1
|
||||
case _:
|
||||
logging.warning(f"Unknown configuration key: {k}")
|
||||
continue
|
||||
|
||||
@ -52,6 +52,7 @@ class UserMgr:
|
||||
result = []
|
||||
for user in users:
|
||||
result.append({
|
||||
'avatar': user.avatar,
|
||||
'email': user.email,
|
||||
'language': user.language,
|
||||
'last_login_time': user.last_login_time,
|
||||
@ -170,7 +171,8 @@ class UserServiceMgr:
|
||||
return [{
|
||||
'title': r['title'],
|
||||
'permission': r['permission'],
|
||||
'canvas_category': r['canvas_category'].split('_')[0]
|
||||
'canvas_category': r['canvas_category'].split('_')[0],
|
||||
'avatar': r['avatar']
|
||||
} for r in res]
|
||||
|
||||
|
||||
@ -190,6 +192,10 @@ class ServiceMgr:
|
||||
config_dict['status'] = 'timeout'
|
||||
except Exception:
|
||||
config_dict['status'] = 'timeout'
|
||||
if not config_dict['host']:
|
||||
config_dict['host'] = '-'
|
||||
if not config_dict['port']:
|
||||
config_dict['port'] = '-'
|
||||
result.append(config_dict)
|
||||
return result
|
||||
|
||||
|
||||
@ -26,7 +26,9 @@ from typing import Any, Union, Tuple
|
||||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
from rag.prompts.generator import chunks_format
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
@ -126,6 +128,7 @@ class Graph:
|
||||
self.components[k]["obj"].reset()
|
||||
try:
|
||||
REDIS_CONN.delete(f"{self.task_id}-logs")
|
||||
REDIS_CONN.delete(f"{self.task_id}-cancel")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -154,7 +157,7 @@ class Graph:
|
||||
return self._tenant_id
|
||||
|
||||
def get_value_with_variable(self,value: str) -> Any:
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*")
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
@ -196,7 +199,7 @@ class Graph:
|
||||
if not rest:
|
||||
return root_val
|
||||
return self.get_variable_param_value(root_val,rest)
|
||||
|
||||
|
||||
def get_variable_param_value(self, obj: Any, path: str) -> Any:
|
||||
cur = obj
|
||||
if not path:
|
||||
@ -215,6 +218,17 @@ class Graph:
|
||||
cur = getattr(cur, key, None)
|
||||
return cur
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
return has_canceled(self.task_id)
|
||||
|
||||
def cancel_task(self) -> bool:
|
||||
try:
|
||||
REDIS_CONN.set(f"{self.task_id}-cancel", "x")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Canvas(Graph):
|
||||
|
||||
@ -239,7 +253,7 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
|
||||
|
||||
self.retrieval = self.dsl["retrieval"]
|
||||
self.memory = self.dsl.get("memory", [])
|
||||
|
||||
@ -256,18 +270,19 @@ class Canvas(Graph):
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
for k in self.globals.keys():
|
||||
if isinstance(self.globals[k], str):
|
||||
self.globals[k] = ""
|
||||
elif isinstance(self.globals[k], int):
|
||||
self.globals[k] = 0
|
||||
elif isinstance(self.globals[k], float):
|
||||
self.globals[k] = 0
|
||||
elif isinstance(self.globals[k], list):
|
||||
self.globals[k] = []
|
||||
elif isinstance(self.globals[k], dict):
|
||||
self.globals[k] = {}
|
||||
else:
|
||||
self.globals[k] = None
|
||||
if k.startswith("sys."):
|
||||
if isinstance(self.globals[k], str):
|
||||
self.globals[k] = ""
|
||||
elif isinstance(self.globals[k], int):
|
||||
self.globals[k] = 0
|
||||
elif isinstance(self.globals[k], float):
|
||||
self.globals[k] = 0
|
||||
elif isinstance(self.globals[k], list):
|
||||
self.globals[k] = []
|
||||
elif isinstance(self.globals[k], dict):
|
||||
self.globals[k] = {}
|
||||
else:
|
||||
self.globals[k] = None
|
||||
|
||||
def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
@ -310,10 +325,20 @@ class Canvas(Graph):
|
||||
self.path.append("begin")
|
||||
self.retrieval.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled before starting."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||
|
||||
def _run_batch(f, t):
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
i = f
|
||||
@ -324,7 +349,7 @@ class Canvas(Graph):
|
||||
i += 1
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
@ -455,9 +480,10 @@ class Canvas(Graph):
|
||||
for c in path:
|
||||
o = self.get_component_obj(c)
|
||||
if o.component_name.lower() == "userfillup":
|
||||
o.invoke()
|
||||
another_inputs.update(o.get_input_elements())
|
||||
if o.get_param("enable_tips"):
|
||||
tips = o.get_param("tips")
|
||||
tips = o.output("tips")
|
||||
self.path = path
|
||||
yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
|
||||
return
|
||||
@ -471,6 +497,14 @@ class Canvas(Graph):
|
||||
"created_at": st,
|
||||
})
|
||||
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
|
||||
elif "Task has been canceled" in self.error:
|
||||
yield decorate("workflow_finished",
|
||||
{
|
||||
"inputs": kwargs.get("inputs"),
|
||||
"outputs": "Task has been canceled",
|
||||
"elapsed_time": time.perf_counter() - st,
|
||||
"created_at": st,
|
||||
})
|
||||
|
||||
def is_reff(self, exp: str) -> bool:
|
||||
exp = exp.strip("{").strip("}")
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import importlib
|
||||
import inspect
|
||||
@ -53,7 +52,7 @@ def component_class(class_name):
|
||||
for module_name in ["agent.component", "agent.tools", "rag.flow"]:
|
||||
try:
|
||||
return getattr(importlib.import_module(module_name), class_name)
|
||||
except Exception as e:
|
||||
logging.warning(f"Can't import module: {module_name}, error: {e}")
|
||||
except Exception:
|
||||
# logging.warning(f"Can't import module: {module_name}, error: {e}")
|
||||
pass
|
||||
assert False, f"Can't import {class_name}"
|
||||
|
||||
@ -139,6 +139,9 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
|
||||
if kwargs.get("user_prompt"):
|
||||
usr_pmt = ""
|
||||
if kwargs.get("reasoning"):
|
||||
@ -152,6 +155,8 @@ class Agent(LLM, ToolBase):
|
||||
self._param.prompts = [{"role": "user", "content": usr_pmt}]
|
||||
|
||||
if not self.tools:
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
return LLM._invoke(self, **kwargs)
|
||||
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
@ -171,6 +176,8 @@ class Agent(LLM, ToolBase):
|
||||
use_tools = []
|
||||
ans = ""
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
ans += delta_ans
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
@ -191,12 +198,16 @@ class Agent(LLM, ToolBase):
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
|
||||
if delta_ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", delta_ans)
|
||||
return
|
||||
answer_without_toolcall += delta_ans
|
||||
yield delta_ans
|
||||
|
||||
@ -271,6 +282,8 @@ class Agent(LLM, ToolBase):
|
||||
st = timer()
|
||||
txt = ""
|
||||
for delta_ans in self._gen_citations(entire_txt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
yield delta_ans, 0
|
||||
txt += delta_ans
|
||||
|
||||
@ -286,6 +299,8 @@ class Agent(LLM, ToolBase):
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
token_count += tk
|
||||
@ -333,6 +348,8 @@ Instructions:
|
||||
6. Focus on delivering VALUE with the information already gathered
|
||||
Respond immediately with your final comprehensive answer.
|
||||
"""
|
||||
if self.check_if_canceled("Agent final instruction"):
|
||||
return
|
||||
append_user_content(hist, final_instruction)
|
||||
|
||||
for txt, tkcnt in complete():
|
||||
|
||||
@ -393,7 +393,7 @@ class ComponentParamBase(ABC):
|
||||
class ComponentBase(ABC):
|
||||
component_name: str
|
||||
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*"
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
@ -417,6 +417,20 @@ class ComponentBase(ABC):
|
||||
self._param = param
|
||||
self._param.check()
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
return self._canvas.is_canceled()
|
||||
|
||||
def check_if_canceled(self, message: str = "") -> bool:
|
||||
if self.is_canceled():
|
||||
task_id = getattr(self._canvas, 'task_id', 'unknown')
|
||||
log_message = f"Task {task_id} has been canceled"
|
||||
if message:
|
||||
log_message += f" during {message}"
|
||||
logging.info(log_message)
|
||||
self.set_output("_ERROR", "Task has been canceled")
|
||||
return True
|
||||
return False
|
||||
|
||||
def invoke(self, **kwargs) -> dict[str, Any]:
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
|
||||
@ -37,7 +37,13 @@ class Begin(UserFillUp):
|
||||
component_name = "Begin"
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Begin processing"):
|
||||
return
|
||||
|
||||
for k, v in kwargs.get("inputs", {}).items():
|
||||
if self.check_if_canceled("Begin processing"):
|
||||
return
|
||||
|
||||
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
|
||||
@ -98,6 +98,9 @@ class Categorize(LLM, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Categorize processing"):
|
||||
return
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if not msg:
|
||||
msg = [{"role": "user", "content": ""}]
|
||||
@ -114,10 +117,18 @@ class Categorize(LLM, ABC):
|
||||
---- Real Data ----
|
||||
{} →
|
||||
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
|
||||
|
||||
if self.check_if_canceled("Categorize processing"):
|
||||
return
|
||||
|
||||
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
|
||||
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
|
||||
if ERROR_PREFIX in ans:
|
||||
raise Exception(ans)
|
||||
|
||||
if self.check_if_canceled("Categorize processing"):
|
||||
return
|
||||
|
||||
# Count the number of times each category appears in the answer.
|
||||
category_counts = {}
|
||||
for c in self._param.category_description.keys():
|
||||
|
||||
@ -47,6 +47,7 @@ class DataOperations(ComponentBase,ABC):
|
||||
inputs = [inputs]
|
||||
for input_ref in inputs:
|
||||
input_object=self._canvas.get_variable_value(input_ref)
|
||||
self.set_input_value(input_ref, input_object)
|
||||
if input_object is None:
|
||||
continue
|
||||
if isinstance(input_object,dict):
|
||||
|
||||
@ -13,7 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
|
||||
|
||||
class UserFillUpParam(ComponentParamBase):
|
||||
@ -31,10 +35,35 @@ class UserFillUp(ComponentBase):
|
||||
component_name = "UserFillUp"
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("UserFillUp processing"):
|
||||
return
|
||||
|
||||
if self._param.enable_tips:
|
||||
content = self._param.tips
|
||||
for k, v in self.get_input_elements_from_text(self._param.tips).items():
|
||||
v = v["value"]
|
||||
ans = ""
|
||||
if isinstance(v, partial):
|
||||
for t in v():
|
||||
ans += t
|
||||
elif isinstance(v, list):
|
||||
ans = ",".join([str(vv) for vv in v])
|
||||
elif not isinstance(v, str):
|
||||
try:
|
||||
ans = json.dumps(v, ensure_ascii=False)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
ans = v
|
||||
if not ans:
|
||||
ans = ""
|
||||
content = re.sub(r"\{%s\}"%k, ans, content)
|
||||
|
||||
self.set_output("tips", content)
|
||||
for k, v in kwargs.get("inputs", {}).items():
|
||||
if self.check_if_canceled("UserFillUp processing"):
|
||||
return
|
||||
self.set_output(k, v)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Waiting for your input..."
|
||||
|
||||
|
||||
|
||||
@ -56,6 +56,9 @@ class Invoke(ComponentBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Invoke processing"):
|
||||
return
|
||||
|
||||
args = {}
|
||||
for para in self._param.variables:
|
||||
if para.get("value"):
|
||||
@ -89,6 +92,9 @@ class Invoke(ComponentBase, ABC):
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("Invoke processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
if method == "get":
|
||||
response = requests.get(url=url, params=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
|
||||
@ -121,6 +127,9 @@ class Invoke(ComponentBase, ABC):
|
||||
|
||||
return self.output("result")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("Invoke processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"Http request error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -56,6 +56,9 @@ class Iteration(ComponentBase, ABC):
|
||||
return cid
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Iteration processing"):
|
||||
return
|
||||
|
||||
arr = self._canvas.get_variable_value(self._param.items_ref)
|
||||
if not isinstance(arr, list):
|
||||
self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr)))
|
||||
|
||||
@ -33,6 +33,9 @@ class IterationItem(ComponentBase, ABC):
|
||||
self._idx = 0
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("IterationItem processing"):
|
||||
return
|
||||
|
||||
parent = self.get_parent()
|
||||
arr = self._canvas.get_variable_value(parent._param.items_ref)
|
||||
if not isinstance(arr, list):
|
||||
@ -40,12 +43,17 @@ class IterationItem(ComponentBase, ABC):
|
||||
raise Exception(parent._param.items_ref + " must be an array, but its type is "+str(type(arr)))
|
||||
|
||||
if self._idx > 0:
|
||||
if self.check_if_canceled("IterationItem processing"):
|
||||
return
|
||||
self.output_collation()
|
||||
|
||||
if self._idx >= len(arr):
|
||||
self._idx = -1
|
||||
return
|
||||
|
||||
if self.check_if_canceled("IterationItem processing"):
|
||||
return
|
||||
|
||||
self.set_output("item", arr[self._idx])
|
||||
self.set_output("index", self._idx)
|
||||
|
||||
@ -80,4 +88,4 @@ class IterationItem(ComponentBase, ABC):
|
||||
return self._idx == -1
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Next turn..."
|
||||
return "Next turn..."
|
||||
|
||||
@ -207,6 +207,9 @@ class LLM(ComponentBase):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
@ -223,6 +226,9 @@ class LLM(ComponentBase):
|
||||
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt += structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
@ -248,6 +254,9 @@ class LLM(ComponentBase):
|
||||
return
|
||||
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
@ -269,6 +278,9 @@ class LLM(ComponentBase):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
for ans in self._generate_streamly(msg):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
@ -287,4 +299,4 @@ class LLM(ComponentBase):
|
||||
|
||||
def thoughts(self) -> str:
|
||||
_, msg,_ = self._prepare_prompt_variables()
|
||||
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
|
||||
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
|
||||
|
||||
@ -89,6 +89,9 @@ class Message(ComponentBase):
|
||||
all_content = ""
|
||||
cache = {}
|
||||
for r in re.finditer(self.variable_ref_patt, rand_cnt, flags=re.DOTALL):
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += rand_cnt[s: r.start()]
|
||||
yield rand_cnt[s: r.start()]
|
||||
s = r.end()
|
||||
@ -99,26 +102,33 @@ class Message(ComponentBase):
|
||||
continue
|
||||
|
||||
v = self._canvas.get_variable_value(exp)
|
||||
if not v:
|
||||
if v is None:
|
||||
v = ""
|
||||
if isinstance(v, partial):
|
||||
cnt = ""
|
||||
for t in v():
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
|
||||
self.set_input_value(exp, cnt)
|
||||
continue
|
||||
elif not isinstance(v, str):
|
||||
try:
|
||||
v = json.dumps(v, ensure_ascii=False, indent=2)
|
||||
v = json.dumps(v, ensure_ascii=False)
|
||||
except Exception:
|
||||
v = str(v)
|
||||
yield v
|
||||
self.set_input_value(exp, v)
|
||||
all_content += v
|
||||
cache[exp] = v
|
||||
|
||||
if s < len(rand_cnt):
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += rand_cnt[s: ]
|
||||
yield rand_cnt[s: ]
|
||||
|
||||
@ -132,6 +142,9 @@ class Message(ComponentBase):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Message processing"):
|
||||
return
|
||||
|
||||
rand_cnt = random.choice(self._param.content)
|
||||
if self._param.stream and not self._is_jinjia2(rand_cnt):
|
||||
self.set_output("content", partial(self._stream, rand_cnt))
|
||||
@ -144,6 +157,9 @@ class Message(ComponentBase):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.check_if_canceled("Message processing"):
|
||||
return
|
||||
|
||||
for n, v in kwargs.items():
|
||||
content = re.sub(n, v, content)
|
||||
|
||||
|
||||
@ -63,17 +63,24 @@ class StringTransform(Message, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("StringTransform processing"):
|
||||
return
|
||||
|
||||
if self._param.method == "split":
|
||||
self._split(kwargs.get("line"))
|
||||
else:
|
||||
self._merge(kwargs)
|
||||
|
||||
def _split(self, line:str|None = None):
|
||||
if self.check_if_canceled("StringTransform split processing"):
|
||||
return
|
||||
|
||||
var = self._canvas.get_variable_value(self._param.split_ref) if not line else line
|
||||
if not var:
|
||||
var = ""
|
||||
assert isinstance(var, str), "The input variable is not a string: {}".format(type(var))
|
||||
self.set_input_value(self._param.split_ref, var)
|
||||
|
||||
res = []
|
||||
for i,s in enumerate(re.split(r"(%s)"%("|".join([re.escape(d) for d in self._param.delimiters])), var, flags=re.DOTALL)):
|
||||
if i % 2 == 1:
|
||||
@ -82,6 +89,9 @@ class StringTransform(Message, ABC):
|
||||
self.set_output("result", res)
|
||||
|
||||
def _merge(self, kwargs:dict[str, str] = {}):
|
||||
if self.check_if_canceled("StringTransform merge processing"):
|
||||
return
|
||||
|
||||
script = self._param.script
|
||||
script, kwargs = self.get_kwargs(script, kwargs, self._param.delimiters[0])
|
||||
|
||||
|
||||
@ -63,9 +63,18 @@ class Switch(ComponentBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Switch processing"):
|
||||
return
|
||||
|
||||
for cond in self._param.conditions:
|
||||
if self.check_if_canceled("Switch processing"):
|
||||
return
|
||||
|
||||
res = []
|
||||
for item in cond["items"]:
|
||||
if self.check_if_canceled("Switch processing"):
|
||||
return
|
||||
|
||||
if not item["cpn_id"]:
|
||||
continue
|
||||
cpn_v = self._canvas.get_variable_value(item["cpn_id"])
|
||||
@ -128,4 +137,4 @@ class Switch(ComponentBase, ABC):
|
||||
raise ValueError('Not supported operator' + operator)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "I’m weighing a few options and will pick the next step shortly."
|
||||
return "I’m weighing a few options and will pick the next step shortly."
|
||||
|
||||
84
agent/component/varaiable_aggregator.py
Normal file
84
agent/component/varaiable_aggregator.py
Normal file
@ -0,0 +1,84 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import os
|
||||
|
||||
from common.connection_utils import timeout
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class VariableAggregatorParam(ComponentParamBase):
|
||||
"""
|
||||
Parameters for VariableAggregator
|
||||
|
||||
- groups: list of dicts {"group_name": str, "variables": [variable selectors]}
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# each group expects: {"group_name": str, "variables": List[str]}
|
||||
self.groups = []
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.groups, "[VariableAggregator] groups")
|
||||
for g in self.groups:
|
||||
if not g.get("group_name"):
|
||||
raise ValueError("[VariableAggregator] group_name can not be empty!")
|
||||
if not g.get("variables"):
|
||||
raise ValueError(
|
||||
f"[VariableAggregator] variables of group `{g.get('group_name')}` can not be empty"
|
||||
)
|
||||
if not isinstance(g.get("variables"), list):
|
||||
raise ValueError(
|
||||
f"[VariableAggregator] variables of group `{g.get('group_name')}` should be a list of strings"
|
||||
)
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"variables": {
|
||||
"name": "Variables",
|
||||
"type": "list",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class VariableAggregator(ComponentBase):
|
||||
component_name = "VariableAggregator"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
|
||||
def _invoke(self, **kwargs):
|
||||
# Group mode: for each group, pick the first available variable
|
||||
for group in self._param.groups:
|
||||
gname = group.get("group_name")
|
||||
|
||||
# record candidate selectors within this group
|
||||
self.set_input_value(f"{gname}.variables", list(group.get("variables", [])))
|
||||
for selector in group.get("variables", []):
|
||||
val = self._canvas.get_variable_value(selector['value'])
|
||||
if val:
|
||||
self.set_output(gname, val)
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _to_object(value: Any) -> Any:
|
||||
# Try to convert value to serializable object if it has to_object()
|
||||
try:
|
||||
return value.to_object() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Aggregating variables from canvas and grouping as configured."
|
||||
519
agent/templates/user_interaction.json
Normal file
519
agent/templates/user_interaction.json
Normal file
@ -0,0 +1,519 @@
|
||||
{
|
||||
"id": 27,
|
||||
"title": {
|
||||
"en": "Interactive Agent",
|
||||
"zh": "可交互的 Agent"
|
||||
},
|
||||
"description": {
|
||||
"en": "During the Agent’s execution, users can actively intervene and interact with the Agent to adjust or guide its output, ensuring the final result aligns with their intentions.",
|
||||
"zh": "在 Agent 的运行过程中,用户可以随时介入,与 Agent 进行交互,以调整或引导生成结果,使最终输出更符合预期。"
|
||||
},
|
||||
"canvas_type": "Agent",
|
||||
"dsl": {
|
||||
"components": {
|
||||
"Agent:LargeFliesMelt": {
|
||||
"downstream": [
|
||||
"UserFillUp:GoldBroomsRelate"
|
||||
],
|
||||
"obj": {
|
||||
"component_name": "Agent",
|
||||
"params": {
|
||||
"cite": true,
|
||||
"delay_after_error": 1,
|
||||
"description": "",
|
||||
"exception_default_value": "",
|
||||
"exception_goto": [],
|
||||
"exception_method": "",
|
||||
"frequencyPenaltyEnabled": false,
|
||||
"frequency_penalty": 0.7,
|
||||
"llm_id": "qwen-turbo@Tongyi-Qianwen",
|
||||
"maxTokensEnabled": false,
|
||||
"max_retries": 3,
|
||||
"max_rounds": 1,
|
||||
"max_tokens": 256,
|
||||
"mcp": [],
|
||||
"message_history_window_size": 12,
|
||||
"outputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": ""
|
||||
},
|
||||
"structured": {}
|
||||
},
|
||||
"presencePenaltyEnabled": false,
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": [
|
||||
{
|
||||
"content": "User query:{sys.query}",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"sys_prompt": "<role>\nYou are the Planning Agent in a multi-agent RAG workflow.\nYour sole job is to design a crisp, executable Search Plan for the next agent. Do not search or answer the user’s question.\n</role>\n<objectives>\nUnderstand the user’s task and decompose it into evidence-seeking steps.\nProduce high-quality queries and retrieval settings tailored to the task type (fact lookup, multi-hop reasoning, comparison, statistics, how-to, etc.).\nIdentify missing information that would materially change the plan (≤3 concise questions).\nOptimize for source trustworthiness, diversity, and recency; define stopping criteria to avoid over-searching.\nAnswer in 150 words.\n<objectives>",
|
||||
"temperature": 0.1,
|
||||
"temperatureEnabled": false,
|
||||
"tools": [],
|
||||
"topPEnabled": false,
|
||||
"top_p": 0.3,
|
||||
"user_prompt": "",
|
||||
"visual_files_var": ""
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
"begin"
|
||||
]
|
||||
},
|
||||
"Agent:TangyWordsType": {
|
||||
"downstream": [
|
||||
"Message:FreshWallsStudy"
|
||||
],
|
||||
"obj": {
|
||||
"component_name": "Agent",
|
||||
"params": {
|
||||
"cite": true,
|
||||
"delay_after_error": 1,
|
||||
"description": "",
|
||||
"exception_default_value": "",
|
||||
"exception_goto": [],
|
||||
"exception_method": "",
|
||||
"frequencyPenaltyEnabled": false,
|
||||
"frequency_penalty": 0.7,
|
||||
"llm_id": "qwen-turbo@Tongyi-Qianwen",
|
||||
"maxTokensEnabled": false,
|
||||
"max_retries": 3,
|
||||
"max_rounds": 1,
|
||||
"max_tokens": 256,
|
||||
"mcp": [],
|
||||
"message_history_window_size": 12,
|
||||
"outputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": ""
|
||||
},
|
||||
"structured": {}
|
||||
},
|
||||
"presencePenaltyEnabled": false,
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": [
|
||||
{
|
||||
"content": "Search Plan: {Agent:LargeFliesMelt@content}\n\n\n\nAwait Response feedback:{UserFillUp:GoldBroomsRelate@instructions}\n",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"sys_prompt": "<role>\nYou are the Search Agent.\nYour job is to execute the approved Search Plan, integrate the Await Response feedback, retrieve evidence, and produce a well-grounded answer.\n</role>\n<objectives>\nTranslate the plan + feedback into concrete searches.\nCollect diverse, trustworthy, and recent evidence meeting the plan’s evidence bar.\nSynthesize a concise answer; include citations next to claims they support.\nIf evidence is insufficient or conflicting, clearly state limitations and propose next steps.\n</objectives>\n <tools>\nRetrieval: You must use Retrieval to do the search.\n </tools>\n",
|
||||
"temperature": 0.1,
|
||||
"temperatureEnabled": false,
|
||||
"tools": [
|
||||
{
|
||||
"component_name": "Retrieval",
|
||||
"name": "Retrieval",
|
||||
"params": {
|
||||
"cross_languages": [],
|
||||
"description": "",
|
||||
"empty_response": "",
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
"type": "string",
|
||||
"value": ""
|
||||
},
|
||||
"json": {
|
||||
"type": "Array<Object>",
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"toc_enhance": false,
|
||||
"top_k": 1024,
|
||||
"top_n": 8,
|
||||
"use_kg": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"topPEnabled": false,
|
||||
"top_p": 0.3,
|
||||
"user_prompt": "",
|
||||
"visual_files_var": ""
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
"UserFillUp:GoldBroomsRelate"
|
||||
]
|
||||
},
|
||||
"Message:FreshWallsStudy": {
|
||||
"downstream": [],
|
||||
"obj": {
|
||||
"component_name": "Message",
|
||||
"params": {
|
||||
"content": [
|
||||
"{Agent:TangyWordsType@content}"
|
||||
]
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
"Agent:TangyWordsType"
|
||||
]
|
||||
},
|
||||
"UserFillUp:GoldBroomsRelate": {
|
||||
"downstream": [
|
||||
"Agent:TangyWordsType"
|
||||
],
|
||||
"obj": {
|
||||
"component_name": "UserFillUp",
|
||||
"params": {
|
||||
"enable_tips": true,
|
||||
"inputs": {
|
||||
"instructions": {
|
||||
"name": "instructions",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "paragraph"
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"instructions": {
|
||||
"name": "instructions",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "paragraph"
|
||||
}
|
||||
},
|
||||
"tips": "Here is my search plan:\n{Agent:LargeFliesMelt@content}\nAre you okay with it?"
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
"Agent:LargeFliesMelt"
|
||||
]
|
||||
},
|
||||
"begin": {
|
||||
"downstream": [
|
||||
"Agent:LargeFliesMelt"
|
||||
],
|
||||
"obj": {
|
||||
"component_name": "Begin",
|
||||
"params": {}
|
||||
},
|
||||
"upstream": []
|
||||
}
|
||||
},
|
||||
"globals": {
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": [],
|
||||
"sys.query": "",
|
||||
"sys.user_id": ""
|
||||
},
|
||||
"graph": {
|
||||
"edges": [
|
||||
{
|
||||
"data": {
|
||||
"isHovered": false
|
||||
},
|
||||
"id": "xy-edge__beginstart-Agent:LargeFliesMeltend",
|
||||
"source": "begin",
|
||||
"sourceHandle": "start",
|
||||
"target": "Agent:LargeFliesMelt",
|
||||
"targetHandle": "end"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isHovered": false
|
||||
},
|
||||
"id": "xy-edge__Agent:LargeFliesMeltstart-UserFillUp:GoldBroomsRelateend",
|
||||
"source": "Agent:LargeFliesMelt",
|
||||
"sourceHandle": "start",
|
||||
"target": "UserFillUp:GoldBroomsRelate",
|
||||
"targetHandle": "end"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isHovered": false
|
||||
},
|
||||
"id": "xy-edge__UserFillUp:GoldBroomsRelatestart-Agent:TangyWordsTypeend",
|
||||
"source": "UserFillUp:GoldBroomsRelate",
|
||||
"sourceHandle": "start",
|
||||
"target": "Agent:TangyWordsType",
|
||||
"targetHandle": "end"
|
||||
},
|
||||
{
|
||||
"id": "xy-edge__Agent:TangyWordsTypetool-Tool:NastyBatsGoend",
|
||||
"source": "Agent:TangyWordsType",
|
||||
"sourceHandle": "tool",
|
||||
"target": "Tool:NastyBatsGo",
|
||||
"targetHandle": "end"
|
||||
},
|
||||
{
|
||||
"id": "xy-edge__Agent:TangyWordsTypestart-Message:FreshWallsStudyend",
|
||||
"source": "Agent:TangyWordsType",
|
||||
"sourceHandle": "start",
|
||||
"target": "Message:FreshWallsStudy",
|
||||
"targetHandle": "end"
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"label": "Begin",
|
||||
"name": "begin"
|
||||
},
|
||||
"dragging": false,
|
||||
"id": "begin",
|
||||
"measured": {
|
||||
"height": 50,
|
||||
"width": 200
|
||||
},
|
||||
"position": {
|
||||
"x": 154.9008789064451,
|
||||
"y": 119.51001744285344
|
||||
},
|
||||
"selected": false,
|
||||
"sourcePosition": "left",
|
||||
"targetPosition": "right",
|
||||
"type": "beginNode"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"form": {
|
||||
"cite": true,
|
||||
"delay_after_error": 1,
|
||||
"description": "",
|
||||
"exception_default_value": "",
|
||||
"exception_goto": [],
|
||||
"exception_method": "",
|
||||
"frequencyPenaltyEnabled": false,
|
||||
"frequency_penalty": 0.7,
|
||||
"llm_id": "qwen-turbo@Tongyi-Qianwen",
|
||||
"maxTokensEnabled": false,
|
||||
"max_retries": 3,
|
||||
"max_rounds": 1,
|
||||
"max_tokens": 256,
|
||||
"mcp": [],
|
||||
"message_history_window_size": 12,
|
||||
"outputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": ""
|
||||
},
|
||||
"structured": {}
|
||||
},
|
||||
"presencePenaltyEnabled": false,
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": [
|
||||
{
|
||||
"content": "User query:{sys.query}",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"sys_prompt": "<role>\nYou are the Planning Agent in a multi-agent RAG workflow.\nYour sole job is to design a crisp, executable Search Plan for the next agent. Do not search or answer the user’s question.\n</role>\n<objectives>\nUnderstand the user’s task and decompose it into evidence-seeking steps.\nProduce high-quality queries and retrieval settings tailored to the task type (fact lookup, multi-hop reasoning, comparison, statistics, how-to, etc.).\nIdentify missing information that would materially change the plan (≤3 concise questions).\nOptimize for source trustworthiness, diversity, and recency; define stopping criteria to avoid over-searching.\nAnswer in 150 words.\n<objectives>",
|
||||
"temperature": 0.1,
|
||||
"temperatureEnabled": false,
|
||||
"tools": [],
|
||||
"topPEnabled": false,
|
||||
"top_p": 0.3,
|
||||
"user_prompt": "",
|
||||
"visual_files_var": ""
|
||||
},
|
||||
"label": "Agent",
|
||||
"name": "Planning Agent"
|
||||
},
|
||||
"dragging": false,
|
||||
"id": "Agent:LargeFliesMelt",
|
||||
"measured": {
|
||||
"height": 90,
|
||||
"width": 200
|
||||
},
|
||||
"position": {
|
||||
"x": 443.96309330796714,
|
||||
"y": 104.61370811205677
|
||||
},
|
||||
"selected": false,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "agentNode"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"form": {
|
||||
"enable_tips": true,
|
||||
"inputs": {
|
||||
"instructions": {
|
||||
"name": "instructions",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "paragraph"
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"instructions": {
|
||||
"name": "instructions",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "paragraph"
|
||||
}
|
||||
},
|
||||
"tips": "Here is my search plan:\n{Agent:LargeFliesMelt@content}\nAre you okay with it?"
|
||||
},
|
||||
"label": "UserFillUp",
|
||||
"name": "Await Response"
|
||||
},
|
||||
"dragging": false,
|
||||
"id": "UserFillUp:GoldBroomsRelate",
|
||||
"measured": {
|
||||
"height": 50,
|
||||
"width": 200
|
||||
},
|
||||
"position": {
|
||||
"x": 683.3409492927474,
|
||||
"y": 116.76274137645598
|
||||
},
|
||||
"selected": false,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "ragNode"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"form": {
|
||||
"cite": true,
|
||||
"delay_after_error": 1,
|
||||
"description": "",
|
||||
"exception_default_value": "",
|
||||
"exception_goto": [],
|
||||
"exception_method": "",
|
||||
"frequencyPenaltyEnabled": false,
|
||||
"frequency_penalty": 0.7,
|
||||
"llm_id": "qwen-turbo@Tongyi-Qianwen",
|
||||
"maxTokensEnabled": false,
|
||||
"max_retries": 3,
|
||||
"max_rounds": 1,
|
||||
"max_tokens": 256,
|
||||
"mcp": [],
|
||||
"message_history_window_size": 12,
|
||||
"outputs": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"value": ""
|
||||
},
|
||||
"structured": {}
|
||||
},
|
||||
"presencePenaltyEnabled": false,
|
||||
"presence_penalty": 0.4,
|
||||
"prompts": [
|
||||
{
|
||||
"content": "Search Plan: {Agent:LargeFliesMelt@content}\n\n\n\nAwait Response feedback:{UserFillUp:GoldBroomsRelate@instructions}\n",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"sys_prompt": "<role>\nYou are the Search Agent.\nYour job is to execute the approved Search Plan, integrate the Await Response feedback, retrieve evidence, and produce a well-grounded answer.\n</role>\n<objectives>\nTranslate the plan + feedback into concrete searches.\nCollect diverse, trustworthy, and recent evidence meeting the plan’s evidence bar.\nSynthesize a concise answer; include citations next to claims they support.\nIf evidence is insufficient or conflicting, clearly state limitations and propose next steps.\n</objectives>\n <tools>\nRetrieval: You must use Retrieval to do the search.\n </tools>\n",
|
||||
"temperature": 0.1,
|
||||
"temperatureEnabled": false,
|
||||
"tools": [
|
||||
{
|
||||
"component_name": "Retrieval",
|
||||
"name": "Retrieval",
|
||||
"params": {
|
||||
"cross_languages": [],
|
||||
"description": "",
|
||||
"empty_response": "",
|
||||
"kb_ids": [],
|
||||
"keywords_similarity_weight": 0.7,
|
||||
"outputs": {
|
||||
"formalized_content": {
|
||||
"type": "string",
|
||||
"value": ""
|
||||
},
|
||||
"json": {
|
||||
"type": "Array<Object>",
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"rerank_id": "",
|
||||
"similarity_threshold": 0.2,
|
||||
"toc_enhance": false,
|
||||
"top_k": 1024,
|
||||
"top_n": 8,
|
||||
"use_kg": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"topPEnabled": false,
|
||||
"top_p": 0.3,
|
||||
"user_prompt": "",
|
||||
"visual_files_var": ""
|
||||
},
|
||||
"label": "Agent",
|
||||
"name": "Search Agent"
|
||||
},
|
||||
"dragging": false,
|
||||
"id": "Agent:TangyWordsType",
|
||||
"measured": {
|
||||
"height": 90,
|
||||
"width": 200
|
||||
},
|
||||
"position": {
|
||||
"x": 944.6411255659472,
|
||||
"y": 99.84499066368488
|
||||
},
|
||||
"selected": true,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "agentNode"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"form": {
|
||||
"description": "This is an agent for a specific task.",
|
||||
"user_prompt": "This is the order you need to send to the agent."
|
||||
},
|
||||
"label": "Tool",
|
||||
"name": "flow.tool_0"
|
||||
},
|
||||
"id": "Tool:NastyBatsGo",
|
||||
"measured": {
|
||||
"height": 50,
|
||||
"width": 200
|
||||
},
|
||||
"position": {
|
||||
"x": 862.6411255659472,
|
||||
"y": 239.84499066368488
|
||||
},
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "toolNode"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"form": {
|
||||
"content": [
|
||||
"{Agent:TangyWordsType@content}"
|
||||
]
|
||||
},
|
||||
"label": "Message",
|
||||
"name": "Message"
|
||||
},
|
||||
"dragging": false,
|
||||
"id": "Message:FreshWallsStudy",
|
||||
"measured": {
|
||||
"height": 50,
|
||||
"width": 200
|
||||
},
|
||||
"position": {
|
||||
"x": 1216.7057997987163,
|
||||
"y": 120.48541298149814
|
||||
},
|
||||
"selected": false,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "messageNode"
|
||||
}
|
||||
]
|
||||
},
|
||||
"history": [],
|
||||
"messages": [],
|
||||
"path": [],
|
||||
"retrieval": [],
|
||||
"variables": {}
|
||||
},
|
||||
"avatar":
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAABYlAAAWJQFJUiTwAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA1FSURBVHgBzVppcFRVFv7e672TTjrprGTrkI1tgFjiAI4TmgEGcEFgVFAHxr2mLFe0cKxyBOeHljVjwVDqzDhVTqksOlOiuCBrCwk7hCAkLAGy70kn3eklvb03571e8rrTCQGi5em6ecu9fe85557znXNuh0EUHTGb5/Ast4TnmXsB3oifBTFVxEsVy7PrZ5lM9RE9oZvTZrPeAe51Gvh85ABGMuonJh6Ra9Mzw2KDm2PXm0ymPoS6ReZ5n5n6p4f45QOsj0ihEXxwdGg9Pta3Ax2R98ErH+wIzMFjFFTlhcwkCMEKT4qE+Nc5np8e5Aa8MAf94UNtmEl5DPYFrpHPER8+xn3U3KNkXqDpKtFaSGi1Wm188MH7655+6inoExLR0toW0N211R+8jJ198YHtiFyDH5EHE8Oy7Iccx/1BeJ4yZRI2bXgHSfokyGQsLH19YDgax4RXGJz450A8v1HGMMw62soMEgRxcVqcrDyG9X95E339fSidPh05Odnot9vB0Qe84CB8wMQglSe8HWPI3HAmJVmLYTOYJ55YxVedOYcTxyux59svsfqpJ5GWkorGhhZYLL0oKSnC2pfX4DemOaIg9n473B63ZCFGOu2YU4QIgikIJiF5yaxcuZTPzzfC42WQnZWJLdu2IiXFgG+/3oOs7Ey0t3fC7/cjMyMdU6dOwpyyMswtmwO9PhmdnYE+5hrcS9Eq8MzSHRfuvT7Dj5qbiKeGY4fKsWL1ShQWFKG1uQVz55Thh5pqVJ87D6fLDY/HA7/PD87PiYxMnToRLzz7LI0zoamphQTxYewpKIzUB6OUJcIoOTEOVpjhcroRH69FQ2MLHANOZKZn4PcPP4DikgKoVWrI5CzkcjmUSiUuXriKRx77I2bMuh0tba1IStKL84TXDTU26hp9z0QzJYVVbhBuJdDL8Vy4yeRK+TrBNHp7e9HR0QWXawDNTa3o7ulB2R2zER+nQXZ2FhRKBVwOl7BloO+IVwGpBtxebPv0v0hNN+CO2beHHT4cD6I9Pvo+iobElmtYk6xkQuE6nvOjpaWNmHeDJxPp7bVCp4vHvv0HsXfvQVSeOgOtRoOFi+aJu+C0O6FVa2AwJEOhkMPhcODEiVNwuJ0w5ubRLsbB7/UPMsvH5HTYxvNBwSUAIT5yQcG4wTlk3+/5bt2cX5chIzMdbW3tOH/+EtRaNex2B1RkKkuXLkJCog41NZdw8MBhFBUVICU1WTQDwZwMKcloJJN74snHMeByoq7uCr1XECBkEzB4I3ge6p4xkp3hiOFj+rdcYFpg5vaZs3H3nXehh0znyPFjWLv2NcyeNQPbtn0hDlRr1GKc2L1rP5QqJeQyMiMZRH8QqKujDa+9+gruWf4A4siPlLQzeXlG+Dl/hAjR2B7KnURQEnyDC4kyXPoS5dN7dn7FR0umJ4d8aNUjBKEdePSxh8iMDuDixcsiCglbK5PJ4PMFUEe4F6B01zfbMeD1iTx8vPUTMi0lXn7uJXT1dIr+Eph/EFViAMqoKGBag89sZFIVuMaTplsJWZYtvwv7yQdqai5CQWYhoMyECUWYv8BE2s0RJ8gk03v/3Q24UHsJ27/YjnaKDa4BD1INBpQfrgDlWoFZRRPgcdNBW4hlLBNGM3mEdMGrmzDfSNt/4OBhdLR1Yv26tXj77U0iEqUYUvDvf34AjUpOaUclssaNQ2cXoZdzAC++/Cds3vwZLl+oxtZPt2INPVefOYH6hsbALtwkDabtg8jG8pK0OURtbR34dMtHaKhrImlZvPXW35FsSMKKFfeivOIwsnKMWLDwbrjdbtEHOB8HFfnF8YoD+Md7m7Bv3y7s2rtfXOjkqUqMFYnMR5tQsCdCQwK+t5MQ+3d/g4cfXiHCZHeXBZ98/D/R5gX0sfRZsGTZAzDNW4SG5kbk5uSSL3DIzc1GaekMkNyQ0Z+PN29FYmICxooEPkNN5Hv3zh0jhgq9PhFXrtSh+nw1Kg4dIah04Oix02JfQUEepk2bgs8//0bcocWLfotXXlqD7u5uzF1wp7hIcnIyKr7fjSZKTyDV5HXkO9EkrfjksQZI62Brnw2pqQbMz5yH5fcug4/zUtCqxIcffSRmsI2NrSKjlp5ebNnyGWG/B2++8QYW3TlfTAhZGQOej8EAw0SYbcTaGLk6kxZR8miGY4wWGRAY6+7pFu1PSLH/88G/yAcG8NcNG3H85CkxSgrBT4gPDqcDE4sn4GjKCbBCsGAki7KDO8AybIBRugqLMFFMhuPDCALJI9CBGfxidKkYfi+Mp0mFGCEkUy888wy0Wq3YJyf/sFOa0W/rR093DzLS0tDXZw3WHkPnE9li5IhzdMCu0EGupnn8/sgxLEISBJTMISKSsVKnCDEtvYY+ovLIKQVGQk1waCF56yDs7+zsotjRDlu/DR6fF4sXL0RN9QWwFD/ik1JJwdwQ7XEspSKudqSqFUhsqcWAtUccH4tC/Ilrh/hlmbB84RcjtZgTSwQMPQtayhqXiTPvvIDHJyXjtWVlUCUkQ6ZQBHZDXIvMR6akOmMALtoxTUom9K2XaPesYl9wsmsSG2aQYTDih4ndQvl89LOb0o72wlkonTYZ9y25BwffWYOexstQJ+jhJ8377b1Ic9ngam4A53GC9zqgpCrPd7YcnFwZqB4l64aVHNXYmBoMDkbMgmMUJNqqH6rEFChvW0wMWjFv0RKom37AuX3bkcnbUKCOQ3LqOFgb69Gu1GGgs5HiiwYTbp0JS8UOMBpdxLrDRXI2WruICmrRGh5VEz9kIj4PtIZMqGdQ1HZZEZ+ZhxmFuVD73WT/bhx5/8+wZk1AvHEqrAYjCdoHH9UUxRMnoefIDggFHitjR9QVK9VyQIBYzNxAHhPEeI6EkJM2vfm3QSnzwpA/FZaOeuzdvAn9GUYUlC2Gj2zfkJoPRfpkNFdW0DgWeVlGKJqr4SKQkFNmOxwHbDSTTFjxTFAuaQ9z7Y8EIUK+RdgIXq1DUsFM2Pva0dfWDGu/C8ZxBhiaTyGXbF/ltsBSfRI+dQKhmB8KbTwUcYlgG2vQ095CZaxy0KwxaOLscEyEtyWYx4ey4ZhloFQ/XLD5o5pgDl46V+q4igOU6PVyGtxy10ooEtPQfKYc+z/cBBfrRM7EKZR3dYs7x/IK6JJT4K2vRq8QTxgpfwgaaizn4MN/pFsyfAsWKRFNFlCPCNRC3k7PjVZKCpvrULZ8BV7927vobmxG7dG9OHK4HDYqW6/W1cI3wCAtdyql8c0U/Z3wUuGUWzwFF8w7oIzTiZYptQ2WCUdoVnKVoFL4lgn3D+5UYKx45YVdCupEsmNhUyRUUhrywGWUoGhCKdpbWmGpO42LZ8/BR/EgnmqNxb97FCqFH9UHd8JG6YjT1U9w7IOP6pNJhUYc3f0VtCSElBt5KDeJFEKKnbwElQLhX0oMI70O52rBBM3nRELRL9HQeh5x8KGy3IzjVbWwWfvxwc7dsFssuHz2NBo7W1Ciywdnt4KOQ+ChkxC1SoWmc+XwmBZRYaUWq0Oe80H26KpV60TNhTQW0fiAVoNFBBPSMs9E+ED08/CNxgm2HZcCy5XjkGt16Dp/Fms3vUfiyGFtv4KK/ftQWFxEhwEMvFQwyTihWFKIJ4Ie+u6V0+XkUjzcDhvcNJ9ckkxgSN6LoNfzsYA08t3ooTYA1UrjrXD+cAC/mj+X0CYJ/cT8oe++ppQiBf2OfjE5tFD2qzekIk9GdXpGDtRUSMnpN4wByptSUhOhdLVK6gFecu4SfY3Nxo0TbX18qhHujHqkjy9G84WT6GptwZUOC2qqziM7Lx16jRI54/MJwTzgKYdiBSgWYopSQ0edJXB6/NhyslYCo+F8Q5KDYGiUHVTiTYkAn9MGfcks9NFRpnAUqc/MIgY5LF1xH9Ioe1247H5MnjQNxZNvAUMptttuQyYdlrXRaYmccqnaLhvSEjWRFdlozEDK+I0KIVZiAsRSkaSdOBfeumPobWkUS9RcOqZJjtchJbsEjt42eEhQpS5ZFBiETHVX68jnPHR040YGHW1GFjQxXED67mbq2IhpJTvJu21QjZ8BFaXRPnq2kP3rE3RouUBVHsGzlQ4PEhLi4CB/aCLo9dLhmRDFs+k3jEt0GC3sQD01Y2jCoatJb6MGXEOe0QnMUPCyI+OWBdDk/AK1x3ejraFeLJYG6DTE7XIhjiC0z9IFJ9UNOq2AOz4UZ2bD6hyokj22enU+zTETN0LMyO16MliOzEmpVsH4i1lIzy2kn0+14CjAUSYHi80uFj8cOfL4wiJkpKfT4QKHLB37HVNuNs9hWc6MkVU5IiL9GCSccMsVKjrOl5HZuKnWtqKj4TI6as+AcffTkaUKfl6WL9rEoQP7NpAenhvNxEOO92LGjx+BGCGfUkKh0ZIv21Bz9vTGJ59+8XlRALPZrFexvLAL0zEGNNLJ80h910FVmoQkU2lpaeBfDYT/OXBzjIm0uhFjQMwN9o2KOH6jhngVmI853xGz2QgZ1lH2MI3yIXFHeP4nNP7YVE9cfMlw7BezTKbvpR3/Bx465XnKBextAAAAAElFTkSuQmCC"
|
||||
}
|
||||
@ -63,12 +63,18 @@ class ArXiv(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("ArXiv processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("ArXiv processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
|
||||
"lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
|
||||
@ -79,12 +85,20 @@ class ArXiv(ToolBase, ABC):
|
||||
max_results=self._param.top_n,
|
||||
sort_by=sort_choices[self._param.sort_by]
|
||||
)
|
||||
self._retrieve_chunks(list(arxiv_client.results(search)),
|
||||
results = list(arxiv_client.results(search))
|
||||
|
||||
if self.check_if_canceled("ArXiv processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(results,
|
||||
get_title=lambda r: r.title,
|
||||
get_url=lambda r: r.pdf_url,
|
||||
get_content=lambda r: r.summary)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("ArXiv processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"ArXiv error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -125,6 +125,9 @@ class ToolBase(ComponentBase):
|
||||
return self._param.get_meta()
|
||||
|
||||
def invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Tool processing"):
|
||||
return
|
||||
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
res = self._invoke(**kwargs)
|
||||
@ -170,4 +173,4 @@ class ToolBase(ComponentBase):
|
||||
self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True)))
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return self._canvas.get_component_name(self._id) + " is running..."
|
||||
return self._canvas.get_component_name(self._id) + " is running..."
|
||||
|
||||
@ -131,10 +131,14 @@ class CodeExec(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("CodeExec processing"):
|
||||
return
|
||||
|
||||
lang = kwargs.get("lang", self._param.lang)
|
||||
script = kwargs.get("script", self._param.script)
|
||||
arguments = {}
|
||||
for k, v in self._param.arguments.items():
|
||||
|
||||
if kwargs.get(k):
|
||||
arguments[k] = kwargs[k]
|
||||
continue
|
||||
@ -149,15 +153,28 @@ class CodeExec(ToolBase, ABC):
|
||||
def _execute_code(self, language: str, code: str, arguments: dict):
|
||||
import requests
|
||||
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
try:
|
||||
code_b64 = self._encode_code(code)
|
||||
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
self.set_output("_ERROR", "construct code request error: " + str(e))
|
||||
|
||||
try:
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return "Task has been canceled"
|
||||
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
|
||||
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return "Task has been canceled"
|
||||
|
||||
if resp.status_code != 200:
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
@ -173,16 +190,25 @@ class CodeExec(ToolBase, ABC):
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
|
||||
if isinstance(rt, tuple):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[i]
|
||||
elif isinstance(rt, dict):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k not in rt or k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[k]
|
||||
else:
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt
|
||||
@ -190,6 +216,9 @@ class CodeExec(ToolBase, ABC):
|
||||
self.set_output("_ERROR", "There is no response from sandbox")
|
||||
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("CodeExec execution"):
|
||||
return
|
||||
|
||||
self.set_output("_ERROR", "Exception executing code: " + str(e))
|
||||
|
||||
return self.output()
|
||||
|
||||
@ -29,7 +29,7 @@ class CrawlerParam(ToolParamBase):
|
||||
super().__init__()
|
||||
self.proxy = None
|
||||
self.extract_type = "markdown"
|
||||
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
|
||||
|
||||
@ -47,18 +47,24 @@ class Crawler(ToolBase, ABC):
|
||||
result = asyncio.run(self.get_web(ans))
|
||||
|
||||
return Crawler.be_output(result)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return Crawler.be_output(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
async def get_web(self, url):
|
||||
if self.check_if_canceled("Crawler async operation"):
|
||||
return
|
||||
|
||||
proxy = self._param.proxy if self._param.proxy else None
|
||||
async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
bypass_cache=True
|
||||
)
|
||||
|
||||
|
||||
if self.check_if_canceled("Crawler async operation"):
|
||||
return
|
||||
|
||||
if self._param.extract_type == 'html':
|
||||
return result.cleaned_html
|
||||
elif self._param.extract_type == 'markdown':
|
||||
|
||||
@ -46,11 +46,16 @@ class DeepL(ComponentBase, ABC):
|
||||
component_name = "DeepL"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
if self.check_if_canceled("DeepL processing"):
|
||||
return
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return DeepL.be_output("")
|
||||
|
||||
if self.check_if_canceled("DeepL processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
translator = deepl.Translator(self._param.auth_key)
|
||||
result = translator.translate_text(ans, source_lang=self._param.source_lang,
|
||||
@ -58,4 +63,6 @@ class DeepL(ComponentBase, ABC):
|
||||
|
||||
return DeepL.be_output(result.text)
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("DeepL processing"):
|
||||
return
|
||||
DeepL.be_output("**Error**:" + str(e))
|
||||
|
||||
@ -75,17 +75,30 @@ class DuckDuckGo(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("DuckDuckGo processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("DuckDuckGo processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
if kwargs.get("topic", "general") == "general":
|
||||
with DDGS() as ddgs:
|
||||
if self.check_if_canceled("DuckDuckGo processing"):
|
||||
return
|
||||
|
||||
# {'title': '', 'href': '', 'body': ''}
|
||||
duck_res = ddgs.text(kwargs["query"], max_results=self._param.top_n)
|
||||
|
||||
if self.check_if_canceled("DuckDuckGo processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(duck_res,
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r.get("href", r.get("url")),
|
||||
@ -94,8 +107,15 @@ class DuckDuckGo(ToolBase, ABC):
|
||||
return self.output("formalized_content")
|
||||
else:
|
||||
with DDGS() as ddgs:
|
||||
if self.check_if_canceled("DuckDuckGo processing"):
|
||||
return
|
||||
|
||||
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
|
||||
duck_res = ddgs.news(kwargs["query"], max_results=self._param.top_n)
|
||||
|
||||
if self.check_if_canceled("DuckDuckGo processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(duck_res,
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r.get("href", r.get("url")),
|
||||
@ -103,6 +123,9 @@ class DuckDuckGo(ToolBase, ABC):
|
||||
self.set_output("json", duck_res)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("DuckDuckGo processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"DuckDuckGo error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -101,19 +101,27 @@ class Email(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Email processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("to_email"):
|
||||
self.set_output("success", False)
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("Email processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse JSON string passed from upstream
|
||||
email_data = kwargs
|
||||
|
||||
# Validate required fields
|
||||
if "to_email" not in email_data:
|
||||
return Email.be_output("Missing required field: to_email")
|
||||
self.set_output("_ERROR", "Missing required field: to_email")
|
||||
self.set_output("success", False)
|
||||
return False
|
||||
|
||||
# Create email object
|
||||
msg = MIMEMultipart('alternative')
|
||||
@ -133,6 +141,9 @@ class Email(ToolBase, ABC):
|
||||
# Connect to SMTP server and send
|
||||
logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}")
|
||||
|
||||
if self.check_if_canceled("Email processing"):
|
||||
return
|
||||
|
||||
context = smtplib.ssl.create_default_context()
|
||||
with smtplib.SMTP(self._param.smtp_server, self._param.smtp_port) as server:
|
||||
server.ehlo()
|
||||
@ -149,6 +160,10 @@ class Email(ToolBase, ABC):
|
||||
|
||||
# Send email
|
||||
logging.info(f"Sending email to recipients: {recipients}")
|
||||
|
||||
if self.check_if_canceled("Email processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
server.send_message(msg, self._param.email, recipients)
|
||||
success = True
|
||||
|
||||
@ -81,6 +81,8 @@ class ExeSQL(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("ExeSQL processing"):
|
||||
return
|
||||
|
||||
def convert_decimals(obj):
|
||||
from decimal import Decimal
|
||||
@ -96,6 +98,9 @@ class ExeSQL(ToolBase, ABC):
|
||||
if not sql:
|
||||
raise Exception("SQL for `ExeSQL` MUST not be empty.")
|
||||
|
||||
if self.check_if_canceled("ExeSQL processing"):
|
||||
return
|
||||
|
||||
vars = self.get_input_elements_from_text(sql)
|
||||
args = {}
|
||||
for k, o in vars.items():
|
||||
@ -108,6 +113,9 @@ class ExeSQL(ToolBase, ABC):
|
||||
self.set_input_value(k, args[k])
|
||||
sql = self.string_format(sql, args)
|
||||
|
||||
if self.check_if_canceled("ExeSQL processing"):
|
||||
return
|
||||
|
||||
sqls = sql.split(";")
|
||||
if self._param.db_type in ["mysql", "mariadb"]:
|
||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
||||
@ -181,6 +189,10 @@ class ExeSQL(ToolBase, ABC):
|
||||
sql_res = []
|
||||
formalized_content = []
|
||||
for single_sql in sqls:
|
||||
if self.check_if_canceled("ExeSQL processing"):
|
||||
ibm_db.close(conn)
|
||||
return
|
||||
|
||||
single_sql = single_sql.replace("```", "").strip()
|
||||
if not single_sql:
|
||||
continue
|
||||
@ -190,6 +202,9 @@ class ExeSQL(ToolBase, ABC):
|
||||
rows = []
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
while row and len(rows) < self._param.max_records:
|
||||
if self.check_if_canceled("ExeSQL processing"):
|
||||
ibm_db.close(conn)
|
||||
return
|
||||
rows.append(row)
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
|
||||
@ -220,6 +235,11 @@ class ExeSQL(ToolBase, ABC):
|
||||
sql_res = []
|
||||
formalized_content = []
|
||||
for single_sql in sqls:
|
||||
if self.check_if_canceled("ExeSQL processing"):
|
||||
cursor.close()
|
||||
db.close()
|
||||
return
|
||||
|
||||
single_sql = single_sql.replace('```','')
|
||||
if not single_sql:
|
||||
continue
|
||||
@ -244,6 +264,9 @@ class ExeSQL(ToolBase, ABC):
|
||||
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
|
||||
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
|
||||
|
||||
cursor.close()
|
||||
db.close()
|
||||
|
||||
self.set_output("json", sql_res)
|
||||
self.set_output("formalized_content", "\n\n".join(formalized_content))
|
||||
return self.output("formalized_content")
|
||||
|
||||
@ -59,17 +59,27 @@ class GitHub(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("GitHub processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("GitHub processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + '&sort=stars&order=desc&per_page=' + str(
|
||||
self._param.top_n)
|
||||
headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'}
|
||||
response = requests.get(url=url, headers=headers).json()
|
||||
|
||||
if self.check_if_canceled("GitHub processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(response['items'],
|
||||
get_title=lambda r: r["name"],
|
||||
get_url=lambda r: r["html_url"],
|
||||
@ -77,6 +87,9 @@ class GitHub(ToolBase, ABC):
|
||||
self.set_output("json", response['items'])
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("GitHub processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"GitHub error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -118,6 +118,9 @@ class Google(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Google processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("q"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
@ -132,8 +135,15 @@ class Google(ToolBase, ABC):
|
||||
}
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("Google processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
search = GoogleSearch(params).get_dict()
|
||||
|
||||
if self.check_if_canceled("Google processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(search["organic_results"],
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r["link"],
|
||||
@ -142,6 +152,9 @@ class Google(ToolBase, ABC):
|
||||
self.set_output("json", search["organic_results"])
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("Google processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"Google error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -65,15 +65,25 @@ class GoogleScholar(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("GoogleScholar processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("GoogleScholar processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low,
|
||||
year_high=self._param.year_high, sort_by=self._param.sort_by)
|
||||
|
||||
if self.check_if_canceled("GoogleScholar processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(scholar_client,
|
||||
get_title=lambda r: r['bib']['title'],
|
||||
get_url=lambda r: r["pub_url"],
|
||||
@ -82,6 +92,9 @@ class GoogleScholar(ToolBase, ABC):
|
||||
self.set_output("json", list(scholar_client))
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("GoogleScholar processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"GoogleScholar error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -50,6 +50,9 @@ class Jin10(ComponentBase, ABC):
|
||||
component_name = "Jin10"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
@ -58,6 +61,9 @@ class Jin10(ComponentBase, ABC):
|
||||
jin10_res = []
|
||||
headers = {'secret-key': self._param.secret_key}
|
||||
try:
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
|
||||
if self._param.type == "flash":
|
||||
params = {
|
||||
'category': self._param.flash_type,
|
||||
@ -69,6 +75,8 @@ class Jin10(ComponentBase, ABC):
|
||||
headers=headers, data=json.dumps(params))
|
||||
response = response.json()
|
||||
for i in response['data']:
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
jin10_res.append({"content": i['data']['content']})
|
||||
if self._param.type == "calendar":
|
||||
params = {
|
||||
@ -79,6 +87,8 @@ class Jin10(ComponentBase, ABC):
|
||||
headers=headers, data=json.dumps(params))
|
||||
|
||||
response = response.json()
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
|
||||
if self._param.type == "symbols":
|
||||
params = {
|
||||
@ -90,8 +100,12 @@ class Jin10(ComponentBase, ABC):
|
||||
url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type,
|
||||
headers=headers, data=json.dumps(params))
|
||||
response = response.json()
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
if self._param.symbols_datatype == "symbols":
|
||||
for i in response['data']:
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
i['Commodity Code'] = i['c']
|
||||
i['Stock Exchange'] = i['e']
|
||||
i['Commodity Name'] = i['n']
|
||||
@ -99,6 +113,8 @@ class Jin10(ComponentBase, ABC):
|
||||
del i['c'], i['e'], i['n'], i['t']
|
||||
if self._param.symbols_datatype == "quotes":
|
||||
for i in response['data']:
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
i['Selling Price'] = i['a']
|
||||
i['Buying Price'] = i['b']
|
||||
i['Commodity Code'] = i['c']
|
||||
@ -120,8 +136,12 @@ class Jin10(ComponentBase, ABC):
|
||||
url='https://open-data-api.jin10.com/data-api/news',
|
||||
headers=headers, data=json.dumps(params))
|
||||
response = response.json()
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("Jin10 processing"):
|
||||
return
|
||||
return Jin10.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not jin10_res:
|
||||
|
||||
@ -71,23 +71,40 @@ class PubMed(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("PubMed processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("PubMed processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
Entrez.email = self._param.email
|
||||
pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=kwargs["query"]))['IdList']
|
||||
|
||||
if self.check_if_canceled("PubMed processing"):
|
||||
return
|
||||
|
||||
pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids),
|
||||
retmode="xml").read().decode("utf-8")))
|
||||
|
||||
if self.check_if_canceled("PubMed processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
|
||||
get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
|
||||
get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
|
||||
get_content=lambda child: self._format_pubmed_content(child),)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("PubMed processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"PubMed error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -58,12 +58,18 @@ class QWeather(ComponentBase, ABC):
|
||||
component_name = "QWeather"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
|
||||
ans = self.get_input()
|
||||
ans = "".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return QWeather.be_output("")
|
||||
|
||||
try:
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
|
||||
response = requests.get(
|
||||
url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json()
|
||||
if response["code"] == "200":
|
||||
@ -71,16 +77,23 @@ class QWeather(ComponentBase, ABC):
|
||||
else:
|
||||
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
|
||||
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
|
||||
base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/"
|
||||
|
||||
if self._param.type == "weather":
|
||||
url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
|
||||
response = requests.get(url=url).json()
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
if response["code"] == "200":
|
||||
if self._param.time_period == "now":
|
||||
return QWeather.be_output(str(response["now"]))
|
||||
else:
|
||||
qweather_res = [{"content": str(i) + "\n"} for i in response["daily"]]
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
if not qweather_res:
|
||||
return QWeather.be_output("")
|
||||
|
||||
@ -92,6 +105,8 @@ class QWeather(ComponentBase, ABC):
|
||||
elif self._param.type == "indices":
|
||||
url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
|
||||
response = requests.get(url=url).json()
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
if response["code"] == "200":
|
||||
indices_res = response["daily"][0]["date"] + "\n" + "\n".join(
|
||||
[i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]])
|
||||
@ -103,9 +118,13 @@ class QWeather(ComponentBase, ABC):
|
||||
elif self._param.type == "airquality":
|
||||
url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
|
||||
response = requests.get(url=url).json()
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
if response["code"] == "200":
|
||||
return QWeather.be_output(str(response["now"]))
|
||||
else:
|
||||
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("Qweather processing"):
|
||||
return
|
||||
return QWeather.be_output("**Error**" + str(e))
|
||||
|
||||
@ -82,8 +82,12 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
kb_ids: list[str] = []
|
||||
for id in self._param.kb_ids:
|
||||
@ -122,7 +126,7 @@ class Retrieval(ToolBase, ABC):
|
||||
vars = self.get_input_elements_from_text(kwargs["query"])
|
||||
vars = {k:o["value"] for k,o in vars.items()}
|
||||
query = self.string_format(kwargs["query"], vars)
|
||||
|
||||
|
||||
doc_ids=[]
|
||||
if self._param.meta_data_filter!={}:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
@ -135,7 +139,7 @@ class Retrieval(ToolBase, ABC):
|
||||
elif self._param.meta_data_filter.get("method") == "manual":
|
||||
filters=self._param.meta_data_filter["manual"]
|
||||
for flt in filters:
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*")
|
||||
pat = re.compile(self.variable_ref_patt)
|
||||
s = flt["value"]
|
||||
out_parts = []
|
||||
last = 0
|
||||
@ -184,9 +188,14 @@ class Retrieval(ToolBase, ABC):
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(query, kbs),
|
||||
)
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
|
||||
if self._param.toc_enhance:
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
if self._param.use_kg:
|
||||
@ -195,6 +204,8 @@ class Retrieval(ToolBase, ABC):
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
else:
|
||||
@ -202,6 +213,8 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if ck["content_with_weight"]:
|
||||
ck["content"] = ck["content_with_weight"]
|
||||
del ck["content_with_weight"]
|
||||
|
||||
@ -79,6 +79,9 @@ class SearXNG(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("SearXNG processing"):
|
||||
return
|
||||
|
||||
# Gracefully handle try-run without inputs
|
||||
query = kwargs.get("query")
|
||||
if not query or not isinstance(query, str) or not query.strip():
|
||||
@ -93,6 +96,9 @@ class SearXNG(ToolBase, ABC):
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("SearXNG processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
search_params = {
|
||||
'q': query,
|
||||
@ -110,6 +116,9 @@ class SearXNG(ToolBase, ABC):
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if self.check_if_canceled("SearXNG processing"):
|
||||
return
|
||||
|
||||
data = response.json()
|
||||
|
||||
if not data or not isinstance(data, dict):
|
||||
@ -121,6 +130,9 @@ class SearXNG(ToolBase, ABC):
|
||||
|
||||
results = results[:self._param.top_n]
|
||||
|
||||
if self.check_if_canceled("SearXNG processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(results,
|
||||
get_title=lambda r: r.get("title", ""),
|
||||
get_url=lambda r: r.get("url", ""),
|
||||
@ -130,10 +142,16 @@ class SearXNG(ToolBase, ABC):
|
||||
return self.output("formalized_content")
|
||||
|
||||
except requests.RequestException as e:
|
||||
if self.check_if_canceled("SearXNG processing"):
|
||||
return
|
||||
|
||||
last_e = f"Network error: {e}"
|
||||
logging.exception(f"SearXNG network error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("SearXNG processing"):
|
||||
return
|
||||
|
||||
last_e = str(e)
|
||||
logging.exception(f"SearXNG error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -103,6 +103,9 @@ class TavilySearch(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("TavilySearch processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
@ -113,10 +116,16 @@ class TavilySearch(ToolBase, ABC):
|
||||
if fld not in kwargs:
|
||||
kwargs[fld] = getattr(self._param, fld)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("TavilySearch processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
kwargs["include_images"] = False
|
||||
kwargs["include_raw_content"] = False
|
||||
res = self.tavily_client.search(**kwargs)
|
||||
if self.check_if_canceled("TavilySearch processing"):
|
||||
return
|
||||
|
||||
self._retrieve_chunks(res["results"],
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r["url"],
|
||||
@ -125,6 +134,9 @@ class TavilySearch(ToolBase, ABC):
|
||||
self.set_output("json", res["results"])
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("TavilySearch processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"Tavily error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
@ -201,6 +213,9 @@ class TavilyExtract(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("TavilyExtract processing"):
|
||||
return
|
||||
|
||||
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
||||
last_e = None
|
||||
for fld in ["urls", "extract_depth", "format"]:
|
||||
@ -209,12 +224,21 @@ class TavilyExtract(ToolBase, ABC):
|
||||
if kwargs.get("urls") and isinstance(kwargs["urls"], str):
|
||||
kwargs["urls"] = kwargs["urls"].split(",")
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("TavilyExtract processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
kwargs["include_images"] = False
|
||||
res = self.tavily_client.extract(**kwargs)
|
||||
if self.check_if_canceled("TavilyExtract processing"):
|
||||
return
|
||||
|
||||
self.set_output("json", res["results"])
|
||||
return self.output("json")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("TavilyExtract processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"Tavily error: {e}")
|
||||
if last_e:
|
||||
|
||||
@ -43,12 +43,18 @@ class TuShare(ComponentBase, ABC):
|
||||
component_name = "TuShare"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
if self.check_if_canceled("TuShare processing"):
|
||||
return
|
||||
|
||||
ans = self.get_input()
|
||||
ans = ",".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return TuShare.be_output("")
|
||||
|
||||
try:
|
||||
if self.check_if_canceled("TuShare processing"):
|
||||
return
|
||||
|
||||
tus_res = []
|
||||
params = {
|
||||
"api_name": "news",
|
||||
@ -58,12 +64,18 @@ class TuShare(ComponentBase, ABC):
|
||||
}
|
||||
response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8'))
|
||||
response = response.json()
|
||||
if self.check_if_canceled("TuShare processing"):
|
||||
return
|
||||
if response['code'] != 0:
|
||||
return TuShare.be_output(response['msg'])
|
||||
df = pd.DataFrame(response['data']['items'])
|
||||
df.columns = response['data']['fields']
|
||||
if self.check_if_canceled("TuShare processing"):
|
||||
return
|
||||
tus_res.append({"content": (df[df['content'].str.contains(self._param.keyword, case=False)]).to_markdown()})
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("TuShare processing"):
|
||||
return
|
||||
return TuShare.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not tus_res:
|
||||
|
||||
@ -70,19 +70,31 @@ class WenCai(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("WenCai processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("report", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("WenCai processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
wencai_res = []
|
||||
res = pywencai.get(query=kwargs["query"], query_type=self._param.query_type, perpage=self._param.top_n)
|
||||
if self.check_if_canceled("WenCai processing"):
|
||||
return
|
||||
|
||||
if isinstance(res, pd.DataFrame):
|
||||
wencai_res.append(res.to_markdown())
|
||||
elif isinstance(res, dict):
|
||||
for item in res.items():
|
||||
if self.check_if_canceled("WenCai processing"):
|
||||
return
|
||||
|
||||
if isinstance(item[1], list):
|
||||
wencai_res.append(item[0] + "\n" + pd.DataFrame(item[1]).to_markdown())
|
||||
elif isinstance(item[1], str):
|
||||
@ -100,6 +112,9 @@ class WenCai(ToolBase, ABC):
|
||||
self.set_output("report", "\n\n".join(wencai_res))
|
||||
return self.output("report")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("WenCai processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"WenCai error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -66,17 +66,26 @@ class Wikipedia(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Wikipedia processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("Wikipedia processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
wikipedia.set_lang(self._param.language)
|
||||
wiki_engine = wikipedia
|
||||
pages = []
|
||||
for p in wiki_engine.search(kwargs["query"], results=self._param.top_n):
|
||||
if self.check_if_canceled("Wikipedia processing"):
|
||||
return
|
||||
|
||||
try:
|
||||
pages.append(wikipedia.page(p))
|
||||
except Exception:
|
||||
@ -87,6 +96,9 @@ class Wikipedia(ToolBase, ABC):
|
||||
get_content=lambda r: r.summary)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("Wikipedia processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"Wikipedia error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -74,15 +74,24 @@ class YahooFinance(ToolBase, ABC):
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("stock_code"):
|
||||
self.set_output("report", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
|
||||
yohoo_res = []
|
||||
try:
|
||||
msft = yf.Ticker(kwargs["stock_code"])
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
|
||||
if self._param.info:
|
||||
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
|
||||
if self._param.history:
|
||||
@ -100,6 +109,9 @@ class YahooFinance(ToolBase, ABC):
|
||||
self.set_output("report", "\n\n".join(yohoo_res))
|
||||
return self.output("report")
|
||||
except Exception as e:
|
||||
if self.check_if_canceled("YahooFinance processing"):
|
||||
return
|
||||
|
||||
last_e = e
|
||||
logging.exception(f"YahooFinance error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
@ -466,10 +466,7 @@ def upload():
|
||||
if "run" in form_data.keys():
|
||||
if request.form.get("run").strip() == "1":
|
||||
try:
|
||||
info = {"run": 1, "progress": 0}
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
|
||||
DocumentService.update_by_id(doc["id"], info)
|
||||
# if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
tenant_id = DocumentService.get_tenant_id(doc["id"])
|
||||
@ -726,8 +723,7 @@ def completion_faq():
|
||||
if "quote" not in req:
|
||||
req["quote"] = True
|
||||
|
||||
msg = []
|
||||
msg.append({"role": "user", "content": req["word"]})
|
||||
msg = [{"role": "user", "content": req["word"]}]
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
@ -156,7 +156,7 @@ def run():
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
|
||||
canvas = Canvas(cvs.dsl, current_user.id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -168,8 +168,10 @@ def run():
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
canvas.cancel_task()
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
@ -177,6 +179,7 @@ def run():
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
resp.call_on_close(lambda: canvas.cancel_task())
|
||||
return resp
|
||||
|
||||
|
||||
@ -410,27 +413,27 @@ def test_db_connect():
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
elif req["db_type"] == 'trino':
|
||||
def _parse_catalog_schema(db: str):
|
||||
if not db:
|
||||
def _parse_catalog_schema(db_name: str):
|
||||
if not db_name:
|
||||
return None, None
|
||||
if "." in db:
|
||||
c, s = db.split(".", 1)
|
||||
elif "/" in db:
|
||||
c, s = db.split("/", 1)
|
||||
if "." in db_name:
|
||||
catalog_name, schema_name = db_name.split(".", 1)
|
||||
elif "/" in db_name:
|
||||
catalog_name, schema_name = db_name.split("/", 1)
|
||||
else:
|
||||
c, s = db, "default"
|
||||
return c, s
|
||||
catalog_name, schema_name = db_name, "default"
|
||||
return catalog_name, schema_name
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception:
|
||||
return server_error_response("Missing dependency 'trino'. Please install: pip install trino")
|
||||
except Exception as e:
|
||||
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
|
||||
|
||||
catalog, schema = _parse_catalog_schema(req["database"])
|
||||
if not catalog:
|
||||
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
|
||||
|
||||
|
||||
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||
|
||||
auth = None
|
||||
@ -479,7 +482,6 @@ def getlistversion(canvas_id):
|
||||
@login_required
|
||||
def getversion( version_id):
|
||||
try:
|
||||
|
||||
e, version = UserCanvasVersionService.get_by_id(version_id)
|
||||
if version:
|
||||
return get_json_result(data=version.to_dict())
|
||||
@ -546,11 +548,11 @@ def trace():
|
||||
cvs_id = request.args.get("canvas_id")
|
||||
msg_id = request.args.get("message_id")
|
||||
try:
|
||||
bin = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs")
|
||||
if not bin:
|
||||
binary = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs")
|
||||
if not binary:
|
||||
return get_json_result(data={})
|
||||
|
||||
return get_json_result(data=json.loads(bin.encode("utf-8")))
|
||||
return get_json_result(data=json.loads(binary.encode("utf-8")))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -604,4 +606,4 @@ def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
return flask.make_response(blob)
|
||||
|
||||
@ -13,16 +13,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from html import escape
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask import make_response, request
|
||||
from flask_login import current_user, login_required
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
|
||||
from api.db import InputType
|
||||
from api.db.services.connector_service import ConnectorService, Connector2KbService, SyncLogsService
|
||||
from api.utils.api_utils import get_json_result, validate_request, get_data_error_result
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
|
||||
from common.constants import RetCode, TaskStatus
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@ -42,8 +52,8 @@ def set_connector():
|
||||
"config": req["config"],
|
||||
"refresh_freq": int(req.get("refresh_freq", 30)),
|
||||
"prune_freq": int(req.get("prune_freq", 720)),
|
||||
"timeout_secs": int(req.get("timeout_secs", 60*29)),
|
||||
"status": TaskStatus.SCHEDULE
|
||||
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
|
||||
"status": TaskStatus.SCHEDULE,
|
||||
}
|
||||
conn["status"] = TaskStatus.SCHEDULE
|
||||
ConnectorService.save(**conn)
|
||||
@ -88,14 +98,14 @@ def resume(connector_id):
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/link", methods=["POST"]) # noqa: F821
|
||||
@validate_request("kb_ids")
|
||||
@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def link_kb(connector_id):
|
||||
@validate_request("kb_id")
|
||||
def rebuild(connector_id):
|
||||
req = request.json
|
||||
errors = Connector2KbService.link_kb(connector_id, req["kb_ids"], current_user.id)
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@ -104,4 +114,182 @@ def link_kb(connector_id):
|
||||
def rm_connector(connector_id):
|
||||
ConnectorService.resume(connector_id, TaskStatus.CANCEL)
|
||||
ConnectorService.delete_by_id(connector_id)
|
||||
return get_json_result(data=True)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
GOOGLE_WEB_FLOW_STATE_PREFIX = "google_drive_web_flow_state"
|
||||
GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
|
||||
WEB_FLOW_TTL_SECS = 15 * 60
|
||||
|
||||
|
||||
def _web_state_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}"
|
||||
|
||||
|
||||
def _web_result_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}"
|
||||
|
||||
|
||||
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
try:
|
||||
return json.loads(payload)
|
||||
except json.JSONDecodeError as exc: # pragma: no cover - defensive
|
||||
raise ValueError("Invalid Google credentials JSON.") from exc
|
||||
|
||||
|
||||
def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
web_section = credentials.get("web")
|
||||
if not isinstance(web_section, dict):
|
||||
raise ValueError("Google OAuth JSON must include a 'web' client configuration to use browser-based authorization.")
|
||||
return {"web": web_section}
|
||||
|
||||
|
||||
def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
status = "success" if success else "error"
|
||||
auto_close = "window.close();" if success else ""
|
||||
escaped_message = escape(message)
|
||||
payload_json = json.dumps(
|
||||
{
|
||||
"type": "ragflow-google-drive-oauth",
|
||||
"status": status,
|
||||
"flowId": flow_id or "",
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
heading="Authorization complete" if success else "Authorization failed",
|
||||
message=escaped_message,
|
||||
payload_json=payload_json,
|
||||
auto_close=auto_close,
|
||||
)
|
||||
response = make_response(html, 200)
|
||||
response.headers["Content-Type"] = "text/html; charset=utf-8"
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("credentials")
|
||||
def start_google_drive_web_oauth():
|
||||
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Google Drive OAuth redirect URI is not configured on the server.",
|
||||
)
|
||||
|
||||
req = request.json or {}
|
||||
raw_credentials = req.get("credentials", "")
|
||||
try:
|
||||
credentials = _load_credentials(raw_credentials)
|
||||
except ValueError as exc:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
||||
|
||||
if credentials.get("refresh_token"):
|
||||
return get_json_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="Uploaded credentials already include a refresh token.",
|
||||
)
|
||||
|
||||
try:
|
||||
client_config = _get_web_client_config(credentials)
|
||||
except ValueError as exc:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
||||
|
||||
flow_id = str(uuid.uuid4())
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
include_granted_scopes="true",
|
||||
prompt="consent",
|
||||
state=flow_id,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to create Google OAuth flow: %s", exc)
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Failed to initialize Google OAuth flow. Please verify the uploaded client configuration.",
|
||||
)
|
||||
|
||||
cache_payload = {
|
||||
"user_id": current_user.id,
|
||||
"client_config": client_config,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"flow_id": flow_id,
|
||||
"authorization_url": authorization_url,
|
||||
"expires_in": WEB_FLOW_TTL_SECS,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
def google_drive_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
|
||||
if not state_cache:
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
|
||||
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
|
||||
return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("flow_id")
|
||||
def poll_google_drive_web_result():
|
||||
req = request.json or {}
|
||||
flow_id = req.get("flow_id")
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
|
||||
if not cache_raw:
|
||||
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
||||
|
||||
result = json.loads(cache_raw)
|
||||
if result.get("user_id") != current_user.id:
|
||||
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
||||
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id))
|
||||
return get_json_result(data={"credentials": result.get("credentials")})
|
||||
|
||||
@ -260,6 +260,8 @@ def list_docs():
|
||||
for doc_item in docs:
|
||||
if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
|
||||
doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
|
||||
if doc_item.get("source_type"):
|
||||
doc_item["source_type"] = doc_item["source_type"].split("/")[0]
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
except Exception as e:
|
||||
|
||||
@ -122,11 +122,12 @@ def update():
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase rename)!")
|
||||
errors = Connector2KbService.link_connectors(kb.id, [conn["id"] for conn in connectors], current_user.id)
|
||||
errors = Connector2KbService.link_connectors(kb.id, [conn for conn in connectors], current_user.id)
|
||||
if errors:
|
||||
logging.error("Link KB errors: ", errors)
|
||||
kb = kb.to_dict()
|
||||
kb.update(req)
|
||||
kb["connectors"] = connectors
|
||||
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
|
||||
@ -33,7 +33,7 @@ from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
def factories():
|
||||
try:
|
||||
fac = get_allowed_llm_factories()
|
||||
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
|
||||
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI", "Builtin"]]
|
||||
llms = LLMService.get_all()
|
||||
mdl_types = {}
|
||||
for m in llms:
|
||||
@ -348,7 +348,7 @@ def list_app():
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value])
|
||||
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
|
||||
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.fid == 'Builtin' or (m.llm_name + "@" + m.fid) in status)]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
|
||||
if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"] == LLMType.EMBEDDING and m["fid"] == "Builtin" and m["llm_name"] == os.getenv("TEI_MODEL", ""):
|
||||
@ -358,7 +358,7 @@ def list_app():
|
||||
for o in objs:
|
||||
if o.llm_name + "@" + o.llm_factory in llm_set:
|
||||
continue
|
||||
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
|
||||
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True, "status": StatusEnum.VALID.value})
|
||||
|
||||
res = {}
|
||||
for m in llms:
|
||||
|
||||
@ -70,4 +70,7 @@ class PipelineTaskType(StrEnum):
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
|
||||
|
||||
PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES = {PipelineTaskType.RAPTOR.lower(), PipelineTaskType.GRAPH_RAG.lower(), PipelineTaskType.MINDMAP.lower()}
|
||||
|
||||
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
|
||||
@ -669,6 +669,7 @@ class LLMFactories(DataBaseModel):
|
||||
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
|
||||
logo = TextField(null=True, help_text="llm logo base64")
|
||||
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
|
||||
rank = IntegerField(default=0, index=False)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
@ -1064,6 +1065,7 @@ class Connector2Kb(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
connector_id = CharField(max_length=32, null=False, index=True)
|
||||
kb_id = CharField(max_length=32, null=False, index=True)
|
||||
auto_parse = CharField(max_length=1, null=False, default="1", index=False)
|
||||
|
||||
class Meta:
|
||||
db_table = "connector2kb"
|
||||
@ -1282,4 +1284,12 @@ def migrate_db():
|
||||
migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -89,13 +89,7 @@ def init_superuser():
|
||||
|
||||
|
||||
def init_llm_factory():
|
||||
try:
|
||||
LLMService.filter_delete([(LLM.fid == "MiniMax" or LLM.fid == "Minimax")])
|
||||
LLMService.filter_delete([(LLM.fid == "cohere")])
|
||||
LLMFactoriesService.filter_delete([LLMFactories.name == "cohere"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
LLMFactoriesService.filter_delete([1 == 1])
|
||||
factory_llm_infos = settings.FACTORY_LLM_INFOS
|
||||
for factory_llm_info in factory_llm_infos:
|
||||
info = deepcopy(factory_llm_info)
|
||||
|
||||
@ -67,6 +67,7 @@ class UserCanvasService(CommonService):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
|
||||
@ -90,7 +90,7 @@ class CommonService:
|
||||
else:
|
||||
query_records = cls.model.select()
|
||||
if reverse is not None:
|
||||
if not order_by or not hasattr(cls, order_by):
|
||||
if not order_by or not hasattr(cls.model, order_by):
|
||||
order_by = "create_time"
|
||||
if reverse is True:
|
||||
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
|
||||
|
||||
@ -54,7 +54,6 @@ class ConnectorService(CommonService):
|
||||
SyncLogsService.update_by_id(task["id"], task)
|
||||
ConnectorService.update_by_id(connector_id, {"status": status})
|
||||
|
||||
|
||||
@classmethod
|
||||
def list(cls, tenant_id):
|
||||
fields = [
|
||||
@ -67,6 +66,17 @@ class ConnectorService(CommonService):
|
||||
cls.model.tenant_id == tenant_id
|
||||
).dicts())
|
||||
|
||||
@classmethod
|
||||
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
|
||||
e, conn = cls.get_by_id(connector_id)
|
||||
if not e:
|
||||
return
|
||||
SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id)
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
SyncLogsService.schedule(connector_id, kb_id, reindex=True)
|
||||
return err
|
||||
|
||||
|
||||
class SyncLogsService(CommonService):
|
||||
model = SyncLogs
|
||||
@ -91,6 +101,7 @@ class SyncLogsService(CommonService):
|
||||
Connector.timeout_secs,
|
||||
Knowledgebase.name.alias("kb_name"),
|
||||
Knowledgebase.avatar.alias("kb_avatar"),
|
||||
Connector2Kb.auto_parse,
|
||||
cls.model.from_beginning.alias("reindex"),
|
||||
cls.model.status
|
||||
]
|
||||
@ -179,7 +190,7 @@ class SyncLogsService(CommonService):
|
||||
.where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src):
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
@ -191,14 +202,17 @@ class SyncLogsService(CommonService):
|
||||
return self.blob
|
||||
|
||||
errs = []
|
||||
files = [FileObj(filename=d["semantic_identifier"]+f".{d['extension']}", blob=d["blob"]) for d in docs]
|
||||
files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
|
||||
doc_ids = []
|
||||
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
|
||||
errs.extend(err)
|
||||
|
||||
kb_table_num_map = {}
|
||||
for doc, _ in doc_blob_pairs:
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
doc_ids.append(doc["id"])
|
||||
if not auto_parse or auto_parse == "0":
|
||||
continue
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
|
||||
return errs, doc_ids
|
||||
|
||||
@ -214,43 +228,21 @@ class Connector2KbService(CommonService):
|
||||
model = Connector2Kb
|
||||
|
||||
@classmethod
|
||||
def link_kb(cls, conn_id:str, kb_ids: list[str], tenant_id:str):
|
||||
arr = cls.query(connector_id=conn_id)
|
||||
old_kb_ids = [a.kb_id for a in arr]
|
||||
for kb_id in kb_ids:
|
||||
if kb_id in old_kb_ids:
|
||||
continue
|
||||
cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
errs = []
|
||||
e, conn = ConnectorService.get_by_id(conn_id)
|
||||
for kb_id in old_kb_ids:
|
||||
if kb_id in kb_ids:
|
||||
continue
|
||||
cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id])
|
||||
SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status==TaskStatus.SCHEDULE], {"status": TaskStatus.CANCEL})
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
if err:
|
||||
errs.append(err)
|
||||
return "\n".join(errs)
|
||||
|
||||
@classmethod
|
||||
def link_connectors(cls, kb_id:str, connector_ids: list[str], tenant_id:str):
|
||||
def link_connectors(cls, kb_id:str, connectors: list[dict], tenant_id:str):
|
||||
arr = cls.query(kb_id=kb_id)
|
||||
old_conn_ids = [a.connector_id for a in arr]
|
||||
for conn_id in connector_ids:
|
||||
connector_ids = []
|
||||
for conn in connectors:
|
||||
conn_id = conn["id"]
|
||||
connector_ids.append(conn_id)
|
||||
if conn_id in old_conn_ids:
|
||||
cls.filter_update([cls.model.connector_id==conn_id, cls.model.kb_id==kb_id], {"auto_parse": conn.get("auto_parse", "1")})
|
||||
continue
|
||||
cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id
|
||||
"kb_id": kb_id,
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
@ -260,11 +252,15 @@ class Connector2KbService(CommonService):
|
||||
continue
|
||||
cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id])
|
||||
e, conn = ConnectorService.get_by_id(conn_id)
|
||||
SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status==TaskStatus.SCHEDULE], {"status": TaskStatus.CANCEL})
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
if err:
|
||||
errs.append(err)
|
||||
if not e:
|
||||
continue
|
||||
#SyncLogsService.filter_delete([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id])
|
||||
# Do not delete docs while unlinking.
|
||||
SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status.in_([TaskStatus.SCHEDULE, TaskStatus.RUNNING])], {"status": TaskStatus.CANCEL})
|
||||
#docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
|
||||
#err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
#if err:
|
||||
# errs.append(err)
|
||||
return "\n".join(errs)
|
||||
|
||||
@classmethod
|
||||
@ -273,6 +269,7 @@ class Connector2KbService(CommonService):
|
||||
Connector.id,
|
||||
Connector.source,
|
||||
Connector.name,
|
||||
cls.model.auto_parse,
|
||||
Connector.status
|
||||
]
|
||||
return list(cls.model.select(*fields)\
|
||||
@ -282,3 +279,5 @@ class Connector2KbService(CommonService):
|
||||
).dicts()
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -619,7 +619,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
|
||||
sys_prompt = """
|
||||
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
|
||||
Ensure that:
|
||||
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
|
||||
2. Write only the SQL, no explanations or additional text.
|
||||
"""
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
@ -640,6 +645,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
sql = re.sub(r"([;;]|```).*", "", sql)
|
||||
sql = re.sub(r"&", "and", sql)
|
||||
if sql[: len("select ")] != "select ":
|
||||
return None, None
|
||||
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
|
||||
|
||||
@ -27,7 +27,7 @@ import xxhash
|
||||
from peewee import fn, Case, JOIN
|
||||
|
||||
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
|
||||
from api.db import FileType, UserTenantRole, CanvasCategory
|
||||
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES, FileType, UserTenantRole, CanvasCategory
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
|
||||
User
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
@ -372,12 +372,16 @@ class DocumentService(CommonService):
|
||||
def get_unfinished_docs(cls):
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
|
||||
cls.model.run, cls.model.parser_id]
|
||||
unfinished_task_query = Task.select(Task.doc_id).where(
|
||||
(Task.progress >= 0) & (Task.progress < 1)
|
||||
)
|
||||
|
||||
docs = cls.model.select(*fields) \
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress < 1,
|
||||
cls.model.progress > 0)
|
||||
(((cls.model.progress < 1) & (cls.model.progress > 0)) |
|
||||
(cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@ -619,13 +623,17 @@ class DocumentService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def begin2parse(cls, docid):
|
||||
cls.update_by_id(
|
||||
docid, {"progress": random.random() * 1 / 100.,
|
||||
"progress_msg": "Task is queued...",
|
||||
"process_begin_at": get_format_time(),
|
||||
"run": TaskStatus.RUNNING.value
|
||||
})
|
||||
def begin2parse(cls, doc_id, keep_progress=False):
|
||||
info = {
|
||||
"progress_msg": "Task is queued...",
|
||||
"process_begin_at": get_format_time(),
|
||||
}
|
||||
if not keep_progress:
|
||||
info["progress"] = random.random() * 1 / 100.
|
||||
info["run"] = TaskStatus.RUNNING.value
|
||||
# keep the doc in DONE state when keep_progress=True for GraphRAG, RAPTOR and Mindmap tasks
|
||||
|
||||
cls.update_by_id(doc_id, info)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -684,8 +692,13 @@ class DocumentService(CommonService):
|
||||
bad = 0
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
doc_progress = doc.progress if doc and doc.progress else 0.0
|
||||
special_task_running = False
|
||||
priority = 0
|
||||
for t in tsks:
|
||||
task_type = (t.task_type or "").lower()
|
||||
if task_type in PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES:
|
||||
special_task_running = True
|
||||
if 0 <= t.progress < 1:
|
||||
finished = False
|
||||
if t.progress == -1:
|
||||
@ -702,13 +715,15 @@ class DocumentService(CommonService):
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
# only for special task and parsed docs and unfinised
|
||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
"process_duration": datetime.timestamp(
|
||||
datetime.now()) -
|
||||
d["process_begin_at"].timestamp(),
|
||||
"run": status}
|
||||
if prg != 0:
|
||||
if prg != 0 and not freeze_progress:
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
info["progress_msg"] = msg
|
||||
@ -755,6 +770,14 @@ class DocumentService(CommonService):
|
||||
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
|
||||
.scalar()
|
||||
)
|
||||
downloaded = (
|
||||
cls.model.select(fn.COUNT(1))
|
||||
.where(
|
||||
cls.model.kb_id == kb_id,
|
||||
cls.model.source_type != "local"
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
row = (
|
||||
cls.model.select(
|
||||
@ -791,6 +814,7 @@ class DocumentService(CommonService):
|
||||
"finished": int(row["finished"]),
|
||||
"failed": int(row["failed"]),
|
||||
"cancelled": int(cancelled),
|
||||
"downloaded": int(downloaded)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -837,7 +861,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
|
||||
"begin_at": datetime.now(),
|
||||
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
@ -849,7 +873,7 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
|
||||
|
||||
task["doc_id"] = fake_doc_id
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(sample_doc_id["id"])
|
||||
DocumentService.begin2parse(sample_doc_id["id"], keep_progress=True)
|
||||
assert REDIS_CONN.queue_product(settings.get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
@ -1003,4 +1027,3 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
return [d["id"] for d, _ in files]
|
||||
|
||||
|
||||
@ -201,6 +201,7 @@ class KnowledgebaseService(CommonService):
|
||||
# will get all permitted kb, be cautious.
|
||||
fields = [
|
||||
cls.model.name,
|
||||
cls.model.avatar,
|
||||
cls.model.language,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
|
||||
@ -159,7 +159,7 @@ class PipelineOperationLogService(CommonService):
|
||||
document_name=document.name,
|
||||
document_suffix=document.suffix,
|
||||
document_type=document.type,
|
||||
source_from="", # TODO: add in the future
|
||||
source_from=document.source_type.split("/")[0],
|
||||
progress=document.progress,
|
||||
progress_msg=document.progress_msg,
|
||||
process_begin_at=document.process_begin_at,
|
||||
|
||||
@ -625,7 +625,7 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
|
||||
|
||||
def get_allowed_llm_factories() -> list:
|
||||
factories = list(LLMFactoriesService.get_all())
|
||||
factories = list(LLMFactoriesService.get_all(reverse=True, order_by="rank"))
|
||||
if settings.ALLOWED_LLM_FACTORIES is None:
|
||||
return factories
|
||||
|
||||
|
||||
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import requests
|
||||
from timeit import default_timer as timer
|
||||
@ -146,7 +148,10 @@ def get_redis_info():
|
||||
def check_ragflow_server_alive():
|
||||
start_time = timer()
|
||||
try:
|
||||
response = requests.get(f'http://{settings.HOST_IP}:{settings.HOST_PORT}/v1/system/ping')
|
||||
url = f'http://{settings.HOST_IP}:{settings.HOST_PORT}/v1/system/ping'
|
||||
if '0.0.0.0' in url:
|
||||
url = url.replace('0.0.0.0', '127.0.0.1')
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
|
||||
else:
|
||||
@ -158,6 +163,26 @@ def check_ragflow_server_alive():
|
||||
}
|
||||
|
||||
|
||||
def check_task_executor_alive():
|
||||
task_executor_heartbeats = {}
|
||||
try:
|
||||
task_executors = REDIS_CONN.smembers("TASKEXE")
|
||||
now = datetime.now().timestamp()
|
||||
for task_executor_id in task_executors:
|
||||
heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now)
|
||||
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
|
||||
task_executor_heartbeats[task_executor_id] = heartbeats
|
||||
if task_executor_heartbeats:
|
||||
return {"status": "alive", "message": task_executor_heartbeats}
|
||||
else:
|
||||
return {"status": "timeout", "message": "Not found any task executor."}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "timeout",
|
||||
"message": f"error: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
def run_health_checks() -> tuple[dict, bool]:
|
||||
result: dict[str, str | dict] = {}
|
||||
|
||||
|
||||
@ -113,7 +113,7 @@ class FileSource(StrEnum):
|
||||
DISCORD = "discord"
|
||||
CONFLUENCE = "confluence"
|
||||
GMAIL = "gmail"
|
||||
GOOGLE_DRIVER = "google_driver"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
JIRA = "jira"
|
||||
SHAREPOINT = "sharepoint"
|
||||
SLACK = "slack"
|
||||
|
||||
@ -10,7 +10,7 @@ from .notion_connector import NotionConnector
|
||||
from .confluence_connector import ConfluenceConnector
|
||||
from .discord_connector import DiscordConnector
|
||||
from .dropbox_connector import DropboxConnector
|
||||
from .google_drive_connector import GoogleDriveConnector
|
||||
from .google_drive.connector import GoogleDriveConnector
|
||||
from .jira_connector import JiraConnector
|
||||
from .sharepoint_connector import SharePointConnector
|
||||
from .teams_connector import TeamsConnector
|
||||
@ -47,4 +47,4 @@ __all__ = [
|
||||
"CredentialExpiredError",
|
||||
"InsufficientPermissionsError",
|
||||
"UnexpectedValidationError"
|
||||
]
|
||||
]
|
||||
|
||||
@ -42,6 +42,8 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
SLACK = "slack"
|
||||
CONFLUENCE = "confluence"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
GMAIL = "gmail"
|
||||
DISCORD = "discord"
|
||||
|
||||
|
||||
@ -100,22 +102,6 @@ NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
== "true"
|
||||
)
|
||||
|
||||
# This is the Oauth token
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
# This is the service account key
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
# The email saved for both auth types
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
|
||||
USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||
|
||||
# Error message substrings
|
||||
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||
|
||||
SCOPE_INSTRUCTIONS = (
|
||||
"You have upgraded RAGFlow without updating the Google Auth scopes. "
|
||||
)
|
||||
|
||||
SLIM_BATCH_SIZE = 100
|
||||
|
||||
# Notion API constants
|
||||
@ -184,6 +170,14 @@ CONFLUENCE_TIMEZONE_OFFSET = float(
|
||||
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
|
||||
)
|
||||
|
||||
CONFLUENCE_SYNC_TIME_BUFFER_SECONDS = int(
|
||||
os.environ.get("CONFLUENCE_SYNC_TIME_BUFFER_SECONDS", ONE_DAY)
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
@ -200,6 +194,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", ""
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback")
|
||||
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
@ -20,6 +20,7 @@ from requests.exceptions import HTTPError
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource, CONTINUE_ON_CONNECTOR_FAILURE, \
|
||||
CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, CONFLUENCE_TIMEZONE_OFFSET, CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE, \
|
||||
CONFLUENCE_SYNC_TIME_BUFFER_SECONDS, \
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID, OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET, _DEFAULT_PAGINATION_LIMIT, \
|
||||
_PROBLEMATIC_EXPANSIONS, _REPLACEMENT_EXPANSIONS, _USER_NOT_FOUND, _COMMENT_EXPANSION_FIELDS, \
|
||||
_ATTACHMENT_EXPANSION_FIELDS, _PAGE_EXPANSION_FIELDS, ONE_DAY, ONE_HOUR, _RESTRICTIONS_EXPANSION_FIELDS, \
|
||||
@ -1289,6 +1290,7 @@ class ConfluenceConnector(
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
time_buffer_seconds: int = CONFLUENCE_SYNC_TIME_BUFFER_SECONDS,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.wiki_base = wiki_base
|
||||
@ -1300,6 +1302,7 @@ class ConfluenceConnector(
|
||||
self.batch_size = batch_size
|
||||
self.labels_to_skip = labels_to_skip
|
||||
self.timezone_offset = timezone_offset
|
||||
self.time_buffer_seconds = max(0, time_buffer_seconds)
|
||||
self.scoped_token = scoped_token
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
@ -1356,6 +1359,24 @@ class ConfluenceConnector(
|
||||
logging.info(f"Setting allow_images to {value}.")
|
||||
self.allow_images = value
|
||||
|
||||
def _adjust_start_for_query(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> SecondsSinceUnixEpoch | None:
|
||||
if not start or start <= 0:
|
||||
return start
|
||||
if self.time_buffer_seconds <= 0:
|
||||
return start
|
||||
return max(0.0, start - self.time_buffer_seconds)
|
||||
|
||||
def _is_newer_than_start(
|
||||
self, doc_time: datetime | None, start: SecondsSinceUnixEpoch | None
|
||||
) -> bool:
|
||||
if not start or start <= 0:
|
||||
return True
|
||||
if doc_time is None:
|
||||
return True
|
||||
return doc_time.timestamp() > start
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@ -1414,9 +1435,10 @@ class ConfluenceConnector(
|
||||
"""
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
# Add time filters
|
||||
if start:
|
||||
query_start = self._adjust_start_for_query(start)
|
||||
if query_start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
start, tz=self.timezone
|
||||
query_start, tz=self.timezone
|
||||
).strftime("%Y-%m-%d %H:%M")
|
||||
page_query += f" and lastmodified >= '{formatted_start_time}'"
|
||||
if end:
|
||||
@ -1436,10 +1458,12 @@ class ConfluenceConnector(
|
||||
) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
|
||||
# Add time filters to avoid reprocessing unchanged attachments during refresh
|
||||
if start:
|
||||
query_start = self._adjust_start_for_query(start)
|
||||
if query_start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
start, tz=self.timezone
|
||||
query_start, tz=self.timezone
|
||||
).strftime("%Y-%m-%d %H:%M")
|
||||
attachment_query += f" and lastmodified >= '{formatted_start_time}'"
|
||||
if end:
|
||||
@ -1447,6 +1471,7 @@ class ConfluenceConnector(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
attachment_query += f" and lastmodified <= '{formatted_end_time}'"
|
||||
|
||||
attachment_query += " order by lastmodified asc"
|
||||
return attachment_query
|
||||
|
||||
@ -1668,7 +1693,8 @@ class ConfluenceConnector(
|
||||
),
|
||||
primary_owners=primary_owners,
|
||||
)
|
||||
attachment_docs.append(attachment_doc)
|
||||
if self._is_newer_than_start(attachment_doc.doc_updated_at, start):
|
||||
attachment_docs.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to extract/summarize attachment {attachment['title']}",
|
||||
@ -1729,7 +1755,8 @@ class ConfluenceConnector(
|
||||
continue
|
||||
|
||||
# yield completed document (or failure)
|
||||
yield doc_or_failure
|
||||
if self._is_newer_than_start(doc_or_failure.doc_updated_at, start):
|
||||
yield doc_or_failure
|
||||
|
||||
# Now get attachments for that page:
|
||||
attachment_docs, attachment_failures = self._fetch_page_attachments(
|
||||
|
||||
@ -63,7 +63,7 @@ def _convert_message_to_document(
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=doc_updated_at,
|
||||
blob=message.content.encode("utf-8"),
|
||||
extension="txt",
|
||||
extension=".txt",
|
||||
size_bytes=len(message.content.encode("utf-8")),
|
||||
)
|
||||
|
||||
@ -275,7 +275,7 @@ class DiscordConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=f"{min_updated_at} -> {max_updated_at}",
|
||||
doc_updated_at=max_updated_at,
|
||||
blob=blob,
|
||||
extension="txt",
|
||||
extension=".txt",
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
|
||||
|
||||
@ -1,39 +1,18 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
|
||||
SLIM_BATCH_SIZE
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
BasicExpertInfo,
|
||||
Document,
|
||||
TextSection,
|
||||
SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
is_mail_service_disabled_error,
|
||||
build_time_range_query,
|
||||
clean_email_and_extract_name,
|
||||
get_message_body,
|
||||
get_google_creds,
|
||||
get_admin_service,
|
||||
get_gmail_service,
|
||||
execute_paginated_retrieval,
|
||||
execute_single_retrieval,
|
||||
time_str_to_utc
|
||||
)
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, DocumentSource
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
|
||||
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
|
||||
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc
|
||||
|
||||
# Constants for Gmail API fields
|
||||
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
|
||||
@ -57,20 +36,18 @@ def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertIn
|
||||
else:
|
||||
first_name = None
|
||||
last_name = None
|
||||
owners.append(
|
||||
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
|
||||
)
|
||||
owners.append(BasicExpertInfo(email=email, first_name=first_name, last_name=last_name))
|
||||
return owners
|
||||
|
||||
|
||||
def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
|
||||
"""Convert Gmail message to text section and metadata."""
|
||||
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
|
||||
|
||||
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
@ -80,71 +57,64 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
|
||||
metadata["subject"] = value
|
||||
if name == "date":
|
||||
metadata["updated_at"] = value
|
||||
|
||||
|
||||
if labels := message.get("labelIds"):
|
||||
metadata["labels"] = labels
|
||||
|
||||
|
||||
message_data = ""
|
||||
for name, value in metadata.items():
|
||||
if name != "updated_at":
|
||||
message_data += f"{name}: {value}\n"
|
||||
|
||||
|
||||
message_body_text: str = get_message_body(payload)
|
||||
|
||||
|
||||
return TextSection(link=link, text=message_body_text + message_data), metadata
|
||||
|
||||
|
||||
def thread_to_document(
|
||||
full_thread: dict[str, Any],
|
||||
email_used_to_fetch_thread: str
|
||||
) -> Document | None:
|
||||
def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: str) -> Document | None:
|
||||
"""Convert Gmail thread to Document object."""
|
||||
all_messages = full_thread.get("messages", [])
|
||||
if not all_messages:
|
||||
return None
|
||||
|
||||
|
||||
sections = []
|
||||
semantic_identifier = ""
|
||||
updated_at = None
|
||||
from_emails: dict[str, str | None] = {}
|
||||
other_emails: dict[str, str | None] = {}
|
||||
|
||||
|
||||
for message in all_messages:
|
||||
section, message_metadata = message_to_section(message)
|
||||
sections.append(section)
|
||||
|
||||
|
||||
for name, value in message_metadata.items():
|
||||
if name in EMAIL_FIELDS:
|
||||
email, display_name = clean_email_and_extract_name(value)
|
||||
if name == "from":
|
||||
from_emails[email] = (
|
||||
display_name if not from_emails.get(email) else None
|
||||
)
|
||||
from_emails[email] = display_name if not from_emails.get(email) else None
|
||||
else:
|
||||
other_emails[email] = (
|
||||
display_name if not other_emails.get(email) else None
|
||||
)
|
||||
|
||||
other_emails[email] = display_name if not other_emails.get(email) else None
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = message_metadata.get("subject", "")
|
||||
|
||||
|
||||
if message_metadata.get("updated_at"):
|
||||
updated_at = message_metadata.get("updated_at")
|
||||
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
|
||||
|
||||
thread_id = full_thread.get("id")
|
||||
if not thread_id:
|
||||
raise ValueError("Thread ID is required")
|
||||
|
||||
|
||||
primary_owners = _get_owners_from_emails(from_emails)
|
||||
secondary_owners = _get_owners_from_emails(other_emails)
|
||||
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = "(no subject)"
|
||||
|
||||
|
||||
return Document(
|
||||
id=thread_id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
@ -164,7 +134,7 @@ def thread_to_document(
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Gmail connector for synchronizing emails from Gmail accounts."""
|
||||
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
@ -174,40 +144,28 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
def primary_admin_email(self) -> str:
|
||||
"""Get primary admin email."""
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
|
||||
return self._primary_admin_email
|
||||
|
||||
@property
|
||||
def google_domain(self) -> str:
|
||||
"""Get Google domain from email."""
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
|
||||
return self._primary_admin_email.split("@")[-1]
|
||||
|
||||
@property
|
||||
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
|
||||
"""Get Google credentials."""
|
||||
if self._creds is None:
|
||||
raise RuntimeError(
|
||||
"Creds missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
raise RuntimeError("Creds missing, should not call this property before calling load_credentials")
|
||||
return self._creds
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
"""Load Gmail credentials."""
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self._primary_admin_email = primary_admin_email
|
||||
|
||||
|
||||
self._creds, new_creds_dict = get_google_creds(
|
||||
credentials=credentials,
|
||||
source=DocumentSource.GMAIL,
|
||||
@ -230,10 +188,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
return emails
|
||||
except HttpError as e:
|
||||
if e.resp.status == 404:
|
||||
logging.warning(
|
||||
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
logging.warning("Received 404 from Admin SDK; this may indicate a personal Gmail account with no Workspace domain. Falling back to single user.")
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
except Exception:
|
||||
@ -247,7 +202,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Fetch Gmail threads within time range."""
|
||||
query = build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
|
||||
|
||||
for user_email in self._get_all_user_emails():
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
try:
|
||||
@ -259,7 +214,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = execute_single_retrieval(
|
||||
full_threads = _execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
@ -271,7 +226,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
|
||||
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
@ -284,7 +239,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
@ -297,9 +252,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
raise PermissionError(SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
|
||||
"""Poll Gmail for documents within time range."""
|
||||
try:
|
||||
yield from self._fetch_threads(start, end)
|
||||
@ -317,7 +270,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Retrieve slim documents for permission synchronization."""
|
||||
query = build_time_range_query(start, end)
|
||||
doc_batch = []
|
||||
|
||||
|
||||
for user_email in self._get_all_user_emails():
|
||||
logging.info(f"Fetching slim threads for user: {user_email}")
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
@ -351,10 +304,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
pass
|
||||
|
||||
0
common/data_source/google_drive/__init__.py
Normal file
0
common/data_source/google_drive/__init__.py
Normal file
1292
common/data_source/google_drive/connector.py
Normal file
1292
common/data_source/google_drive/connector.py
Normal file
File diff suppressed because it is too large
Load Diff
4
common/data_source/google_drive/constant.py
Normal file
4
common/data_source/google_drive/constant.py
Normal file
@ -0,0 +1,4 @@
|
||||
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
|
||||
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
|
||||
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
|
||||
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
|
||||
607
common/data_source/google_drive/doc_conversion.py
Normal file
607
common/data_source/google_drive/doc_conversion.py
Normal file
@ -0,0 +1,607 @@
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore # type: ignore
|
||||
from googleapiclient.http import MediaIoBaseDownload # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.data_source.config import DocumentSource, FileOrigin
|
||||
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
|
||||
from common.data_source.google_drive.model import GDriveMimeType, GoogleDriveFileType
|
||||
from common.data_source.google_drive.section_extraction import HEADING_DELIMITER
|
||||
from common.data_source.google_util.resource import GoogleDriveService, get_drive_service
|
||||
from common.data_source.models import ConnectorFailure, Document, DocumentFailure, ImageSection, SlimDocument, TextSection
|
||||
from common.data_source.utils import get_file_ext
|
||||
|
||||
# Image types that should be excluded from processing
|
||||
EXCLUDED_IMAGE_TYPES = [
|
||||
"image/bmp",
|
||||
"image/tiff",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"image/avif",
|
||||
]
|
||||
|
||||
GOOGLE_MIME_TYPES_TO_EXPORT = {
|
||||
GDriveMimeType.DOC.value: "text/plain",
|
||||
GDriveMimeType.SPREADSHEET.value: "text/csv",
|
||||
GDriveMimeType.PPT.value: "text/plain",
|
||||
}
|
||||
|
||||
GOOGLE_NATIVE_EXPORT_TARGETS: dict[str, tuple[str, str]] = {
|
||||
GDriveMimeType.DOC.value: ("application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".docx"),
|
||||
GDriveMimeType.SPREADSHEET.value: ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx"),
|
||||
GDriveMimeType.PPT.value: ("application/vnd.openxmlformats-officedocument.presentationml.presentation", ".pptx"),
|
||||
}
|
||||
GOOGLE_NATIVE_EXPORT_FALLBACK: tuple[str, str] = ("application/pdf", ".pdf")
|
||||
|
||||
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
|
||||
".txt",
|
||||
".md",
|
||||
".mdx",
|
||||
".conf",
|
||||
".log",
|
||||
".json",
|
||||
".csv",
|
||||
".tsv",
|
||||
".xml",
|
||||
".yml",
|
||||
".yaml",
|
||||
".sql",
|
||||
]
|
||||
|
||||
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
|
||||
".pdf",
|
||||
".docx",
|
||||
".pptx",
|
||||
".xlsx",
|
||||
".eml",
|
||||
".epub",
|
||||
".html",
|
||||
]
|
||||
|
||||
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".webp",
|
||||
]
|
||||
|
||||
ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DOCUMENT_FILE_EXTENSIONS + ACCEPTED_IMAGE_FILE_EXTENSIONS
|
||||
|
||||
MAX_RETRIEVER_EMAILS = 20
|
||||
CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read
|
||||
# This is not a standard valid unicode char, it is used by the docs advanced API to
|
||||
# represent smart chips (elements like dates and doc links).
|
||||
SMART_CHIP_CHAR = "\ue907"
|
||||
WEB_VIEW_LINK_KEY = "webViewLink"
|
||||
# Fallback templates for generating web links when Drive omits webViewLink.
|
||||
_FALLBACK_WEB_VIEW_LINK_TEMPLATES = {
|
||||
GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view",
|
||||
GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view",
|
||||
GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view",
|
||||
}
|
||||
|
||||
|
||||
class PermissionSyncContext(BaseModel):
|
||||
"""
|
||||
This is the information that is needed to sync permissions for a document.
|
||||
"""
|
||||
|
||||
primary_admin_email: str
|
||||
google_domain: str
|
||||
|
||||
|
||||
def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
|
||||
link = file.get(WEB_VIEW_LINK_KEY)
|
||||
if not link:
|
||||
file_id = file.get("id")
|
||||
if not file_id:
|
||||
raise KeyError(f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields.")
|
||||
mime_type = file.get("mimeType", "")
|
||||
template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type)
|
||||
if template is None:
|
||||
link = f"https://drive.google.com/file/d/{file_id}/view"
|
||||
else:
|
||||
link = template.format(file_id)
|
||||
logging.debug(
|
||||
"Missing webViewLink for Google Drive file with id %s. Falling back to constructed link %s",
|
||||
file_id,
|
||||
link,
|
||||
)
|
||||
parsed_url = urlparse(link)
|
||||
parsed_url = parsed_url._replace(query="") # remove query parameters
|
||||
spl_path = parsed_url.path.split("/")
|
||||
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
|
||||
spl_path.pop()
|
||||
parsed_url = parsed_url._replace(path="/".join(spl_path))
|
||||
# Remove query parameters and reconstruct URL
|
||||
return urlunparse(parsed_url)
|
||||
|
||||
|
||||
def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
|
||||
start = haystack.find(needle, start)
|
||||
while start >= 0 and n > 1:
|
||||
start = haystack.find(needle, start + len(needle))
|
||||
n -= 1
|
||||
return start
|
||||
|
||||
|
||||
def align_basic_advanced(basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]) -> list[TextSection | ImageSection]:
|
||||
"""Align the basic sections with the advanced sections.
|
||||
In particular, the basic sections contain all content of the file,
|
||||
including smart chips like dates and doc links. The advanced sections
|
||||
are separated by section headers and contain header-based links that
|
||||
improve user experience when they click on the source in the UI.
|
||||
|
||||
There are edge cases in text matching (i.e. the heading is a smart chip or
|
||||
there is a smart chip in the doc with text containing the actual heading text)
|
||||
that make the matching imperfect; this is hence done on a best-effort basis.
|
||||
"""
|
||||
if len(adv_sections) <= 1:
|
||||
return basic_sections # no benefit from aligning
|
||||
|
||||
basic_full_text = "".join([section.text for section in basic_sections if isinstance(section, TextSection)])
|
||||
new_sections: list[TextSection | ImageSection] = []
|
||||
heading_start = 0
|
||||
for adv_ind in range(1, len(adv_sections)):
|
||||
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
|
||||
# retrieve the longest part of the heading that is not a smart chip
|
||||
heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
|
||||
if heading_key == "":
|
||||
logging.warning(f"Cannot match heading: {heading}, its link will come from the following section")
|
||||
continue
|
||||
heading_offset = heading.find(heading_key)
|
||||
|
||||
# count occurrences of heading str in previous section
|
||||
heading_count = adv_sections[adv_ind - 1].text.count(heading_key)
|
||||
|
||||
prev_start = heading_start
|
||||
heading_start = _find_nth(basic_full_text, heading_key, heading_count, start=prev_start) - heading_offset
|
||||
if heading_start < 0:
|
||||
logging.warning(f"Heading key {heading_key} from heading {heading} not found in basic text")
|
||||
heading_start = prev_start
|
||||
continue
|
||||
|
||||
new_sections.append(
|
||||
TextSection(
|
||||
link=adv_sections[adv_ind - 1].link,
|
||||
text=basic_full_text[prev_start:heading_start],
|
||||
)
|
||||
)
|
||||
|
||||
# handle last section
|
||||
new_sections.append(TextSection(link=adv_sections[-1].link, text=basic_full_text[heading_start:]))
|
||||
return new_sections
|
||||
|
||||
|
||||
def is_valid_image_type(mime_type: str) -> bool:
|
||||
"""
|
||||
Check if mime_type is a valid image type.
|
||||
|
||||
Args:
|
||||
mime_type: The MIME type to check
|
||||
|
||||
Returns:
|
||||
True if the MIME type is a valid image type, False otherwise
|
||||
"""
|
||||
return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
|
||||
|
||||
def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
"""
|
||||
Return True if the mime_type is a common image type in GDrive.
|
||||
(e.g. 'image/png', 'image/jpeg')
|
||||
"""
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def _get_extension_from_file(file: GoogleDriveFileType, mime_type: str, fallback: str = ".bin") -> str:
|
||||
file_name = file.get("name") or ""
|
||||
if file_name:
|
||||
suffix = Path(file_name).suffix
|
||||
if suffix:
|
||||
return suffix
|
||||
|
||||
file_extension = file.get("fileExtension")
|
||||
if file_extension:
|
||||
return f".{file_extension.lstrip('.')}"
|
||||
|
||||
guessed = mimetypes.guess_extension(mime_type or "")
|
||||
if guessed:
|
||||
return guessed
|
||||
|
||||
return fallback
|
||||
|
||||
|
||||
def _download_file_blob(
|
||||
service: GoogleDriveService,
|
||||
file: GoogleDriveFileType,
|
||||
size_threshold: int,
|
||||
allow_images: bool,
|
||||
) -> tuple[bytes, str] | None:
|
||||
mime_type = file.get("mimeType", "")
|
||||
file_id = file.get("id")
|
||||
if not file_id:
|
||||
logging.warning("Encountered Google Drive file without id.")
|
||||
return None
|
||||
|
||||
if is_gdrive_image_mime_type(mime_type) and not allow_images:
|
||||
logging.debug(f"Skipping image {file.get('name')} because allow_images is False.")
|
||||
return None
|
||||
|
||||
blob: bytes = b""
|
||||
extension = ".bin"
|
||||
try:
|
||||
if mime_type in GOOGLE_NATIVE_EXPORT_TARGETS:
|
||||
export_mime, extension = GOOGLE_NATIVE_EXPORT_TARGETS[mime_type]
|
||||
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
|
||||
blob = _download_request(request, file_id, size_threshold)
|
||||
elif mime_type.startswith("application/vnd.google-apps"):
|
||||
export_mime, extension = GOOGLE_NATIVE_EXPORT_FALLBACK
|
||||
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
|
||||
blob = _download_request(request, file_id, size_threshold)
|
||||
else:
|
||||
extension = _get_extension_from_file(file, mime_type)
|
||||
blob = download_request(service, file_id, size_threshold)
|
||||
except HttpError:
|
||||
raise
|
||||
|
||||
if not blob:
|
||||
return None
|
||||
if not extension:
|
||||
extension = _get_extension_from_file(file, mime_type)
|
||||
return blob, extension
|
||||
|
||||
|
||||
def download_request(service: GoogleDriveService, file_id: str, size_threshold: int) -> bytes:
|
||||
"""
|
||||
Download the file from Google Drive.
|
||||
"""
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
return _download_request(request, file_id, size_threshold)
|
||||
|
||||
|
||||
def _download_request(request: Any, file_id: str, size_threshold: int) -> bytes:
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request, chunksize=size_threshold + CHUNK_SIZE_BUFFER)
|
||||
done = False
|
||||
while not done:
|
||||
download_progress, done = downloader.next_chunk()
|
||||
if download_progress.resumable_progress > size_threshold:
|
||||
logging.warning(f"File {file_id} exceeds size threshold of {size_threshold}. Skipping2.")
|
||||
return bytes()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logging.warning(f"Failed to download {file_id}")
|
||||
return bytes()
|
||||
return response
|
||||
|
||||
|
||||
def _download_and_extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
) -> list[TextSection | ImageSection]:
|
||||
"""Extract text and images from a Google Drive file."""
|
||||
file_id = file["id"]
|
||||
file_name = file["name"]
|
||||
mime_type = file["mimeType"]
|
||||
link = file.get(WEB_VIEW_LINK_KEY, "")
|
||||
|
||||
# For non-Google files, download the file
|
||||
# Use the correct API call for downloading files
|
||||
# lazy evaluation to only download the file if necessary
|
||||
def response_call() -> bytes:
|
||||
return download_request(service, file_id, size_threshold)
|
||||
|
||||
if is_gdrive_image_mime_type(mime_type):
|
||||
# Skip images if not explicitly enabled
|
||||
if not allow_images:
|
||||
return []
|
||||
|
||||
# Store images for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
def store_image_and_create_section(**kwargs):
|
||||
pass
|
||||
|
||||
try:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=response_call(),
|
||||
file_id=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(fileId=file_id, mimeType=export_mime_type)
|
||||
response = _download_request(request, file_id, size_threshold)
|
||||
if not response:
|
||||
logging.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
try:
|
||||
text = response_call().decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
except UnicodeDecodeError as e:
|
||||
logging.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
elif mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
|
||||
def docx_to_text_and_images(*args, **kwargs):
|
||||
return "docx_to_text_and_images"
|
||||
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
|
||||
|
||||
def xlsx_to_text(*args, **kwargs):
|
||||
return "xlsx_to_text"
|
||||
|
||||
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
|
||||
return [TextSection(link=link, text=text)] if text else []
|
||||
|
||||
elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
|
||||
|
||||
def pptx_to_text(*args, **kwargs):
|
||||
return "pptx_to_text"
|
||||
|
||||
text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
|
||||
return [TextSection(link=link, text=text)] if text else []
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
|
||||
def read_pdf_file(*args, **kwargs):
|
||||
return "read_pdf_file"
|
||||
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
|
||||
pdf_sections: list[TextSection | ImageSection] = [TextSection(link=link, text=text)]
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=img_data,
|
||||
file_id=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
# Final attempt at extracting text
|
||||
file_ext = get_file_ext(file.get("name", ""))
|
||||
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
|
||||
logging.warning(f"Skipping file {file.get('name')} due to extension.")
|
||||
return []
|
||||
|
||||
try:
|
||||
|
||||
def extract_file_text(*args, **kwargs):
|
||||
return "extract_file_text"
|
||||
|
||||
text = extract_file_text(io.BytesIO(response_call()), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _convert_drive_item_to_document(
|
||||
creds: Any,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
retriever_email: str,
|
||||
file: GoogleDriveFileType,
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
"""
|
||||
|
||||
def _get_drive_service() -> GoogleDriveService:
|
||||
return get_drive_service(creds, user_email=retriever_email)
|
||||
|
||||
doc_id = "unknown"
|
||||
link = file.get(WEB_VIEW_LINK_KEY)
|
||||
|
||||
try:
|
||||
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
|
||||
logging.info("Skipping shortcut/folder.")
|
||||
return None
|
||||
|
||||
size_str = file.get("size")
|
||||
if size_str:
|
||||
try:
|
||||
size_int = int(size_str)
|
||||
except ValueError:
|
||||
logging.warning(f"Parsing string to int failed: size_str={size_str}")
|
||||
else:
|
||||
if size_int > size_threshold:
|
||||
logging.warning(f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping.")
|
||||
return None
|
||||
|
||||
blob_and_ext = _download_file_blob(
|
||||
service=_get_drive_service(),
|
||||
file=file,
|
||||
size_threshold=size_threshold,
|
||||
allow_images=allow_images,
|
||||
)
|
||||
|
||||
if blob_and_ext is None:
|
||||
logging.info(f"Skipping file {file.get('name')} due to incompatible type or download failure.")
|
||||
return None
|
||||
|
||||
blob, extension = blob_and_ext
|
||||
if not blob:
|
||||
logging.warning(f"Failed to download {file.get('name')}. Skipping.")
|
||||
return None
|
||||
|
||||
doc_id = onyx_document_id_from_drive_file(file)
|
||||
modified_time = file.get("modifiedTime")
|
||||
try:
|
||||
doc_updated_at = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) if modified_time else datetime.now(timezone.utc)
|
||||
except ValueError:
|
||||
logging.warning(f"Failed to parse modifiedTime for {file.get('name')}, defaulting to current time.")
|
||||
doc_updated_at = datetime.now(timezone.utc)
|
||||
|
||||
return Document(
|
||||
id=doc_id,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file.get("name", ""),
|
||||
blob=blob,
|
||||
extension=extension,
|
||||
size_bytes=len(blob),
|
||||
doc_updated_at=doc_updated_at,
|
||||
)
|
||||
except Exception as e:
|
||||
doc_id = "unknown"
|
||||
try:
|
||||
doc_id = onyx_document_id_from_drive_file(file)
|
||||
except Exception as e2:
|
||||
logging.warning(f"Error getting document id from file: {e2}")
|
||||
|
||||
file_name = file.get("name", doc_id)
|
||||
error_str = f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
|
||||
if isinstance(e, HttpError) and e.status_code == 403:
|
||||
logging.warning(f"Uncommon permissions error while downloading file. User {retriever_email} was able to see file {file_name} but cannot download it.")
|
||||
logging.warning(error_str)
|
||||
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=link,
|
||||
),
|
||||
failed_entity=None,
|
||||
failure_message=error_str,
|
||||
exception=e,
|
||||
)
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
creds: Any,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
retriever_emails: list[str],
|
||||
file: GoogleDriveFileType,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Attempt to convert a drive item to a document with each retriever email
|
||||
in order. returns upon a successful retrieval or a non-403 error.
|
||||
|
||||
We used to always get the user email from the file owners when available,
|
||||
but this was causing issues with shared folders where the owner was not included in the service account
|
||||
now we use the email of the account that successfully listed the file. There are cases where a
|
||||
user that can list a file cannot download it, so we retry with file owners and admin email.
|
||||
"""
|
||||
first_error = None
|
||||
doc_or_failure = None
|
||||
retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
|
||||
# use seen instead of list(set()) to avoid re-ordering the retriever emails
|
||||
seen = set()
|
||||
for retriever_email in retriever_emails:
|
||||
if retriever_email in seen:
|
||||
continue
|
||||
seen.add(retriever_email)
|
||||
doc_or_failure = _convert_drive_item_to_document(
|
||||
creds,
|
||||
allow_images,
|
||||
size_threshold,
|
||||
retriever_email,
|
||||
file,
|
||||
permission_sync_context,
|
||||
)
|
||||
|
||||
# There are a variety of permissions-based errors that occasionally occur
|
||||
# when retrieving files. Often when these occur, there is another user
|
||||
# that can successfully retrieve the file, so we try the next user.
|
||||
if doc_or_failure is None or isinstance(doc_or_failure, Document) or not (isinstance(doc_or_failure.exception, HttpError) and doc_or_failure.exception.status_code in [401, 403, 404]):
|
||||
return doc_or_failure
|
||||
|
||||
if first_error is None:
|
||||
first_error = doc_or_failure
|
||||
else:
|
||||
first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"
|
||||
|
||||
if first_error and isinstance(first_error.exception, HttpError) and first_error.exception.status_code == 403:
|
||||
# This SHOULD happen very rarely, and we don't want to break the indexing process when
|
||||
# a high volume of 403s occurs early. We leave a verbose log to help investigate.
|
||||
logging.error(
|
||||
f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error.Attempted to retrieve with {retriever_emails},got the following errors: {first_error.failure_message}"
|
||||
)
|
||||
return None
|
||||
return first_error
|
||||
|
||||
|
||||
def build_slim_document(
|
||||
creds: Any,
|
||||
file: GoogleDriveFileType,
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
) -> SlimDocument | None:
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
|
||||
owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))
|
||||
|
||||
def _get_external_access_for_raw_gdrive_file(*args, **kwargs):
|
||||
return None
|
||||
|
||||
external_access = (
|
||||
_get_external_access_for_raw_gdrive_file(
|
||||
file=file,
|
||||
company_domain=permission_sync_context.google_domain,
|
||||
retriever_drive_service=(
|
||||
get_drive_service(
|
||||
creds,
|
||||
user_email=owner_email,
|
||||
)
|
||||
if owner_email
|
||||
else None
|
||||
),
|
||||
admin_drive_service=get_drive_service(
|
||||
creds,
|
||||
user_email=permission_sync_context.primary_admin_email,
|
||||
),
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
)
|
||||
return SlimDocument(
|
||||
id=onyx_document_id_from_drive_file(file),
|
||||
external_access=external_access,
|
||||
)
|
||||
346
common/data_source/google_drive/file_retrieval.py
Normal file
346
common/data_source/google_drive/file_retrieval.py
Normal file
@ -0,0 +1,346 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
|
||||
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveFileType, RetrievedDriveFile
|
||||
from common.data_source.google_util.resource import GoogleDriveService
|
||||
from common.data_source.google_util.util import ORDER_BY_KEY, PAGE_TOKEN_KEY, GoogleFields, execute_paginated_retrieval, execute_paginated_retrieval_with_max_pages
|
||||
from common.data_source.models import SecondsSinceUnixEpoch
|
||||
|
||||
PERMISSION_FULL_DESCRIPTION = "permissions(id, emailAddress, type, domain, permissionDetails)"
|
||||
|
||||
FILE_FIELDS = "nextPageToken, files(mimeType, id, name, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
FILE_FIELDS_WITH_PERMISSIONS = f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
SLIM_FILE_FIELDS = f"nextPageToken, files(mimeType, driveId, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, webViewLink, owners(emailAddress), modifiedTime)"
|
||||
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
|
||||
class DriveFileFieldType(Enum):
|
||||
"""Enum to specify which fields to retrieve from Google Drive files"""
|
||||
|
||||
SLIM = "slim" # Minimal fields for basic file info
|
||||
STANDARD = "standard" # Standard fields including content metadata
|
||||
WITH_PERMISSIONS = "with_permissions" # Full fields including permissions
|
||||
|
||||
|
||||
def generate_time_range_filter(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
time_range_filter = ""
|
||||
if start is not None:
|
||||
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
|
||||
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} > '{time_start}'"
|
||||
if end is not None:
|
||||
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
|
||||
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
|
||||
return time_range_filter
|
||||
|
||||
|
||||
def _get_folders_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# Follow shortcuts to folders
|
||||
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
|
||||
query += " and trashed = false"
|
||||
|
||||
if parent_id:
|
||||
query += f" and '{parent_id}' in parents"
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=FOLDER_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string based on the field type enum"""
|
||||
if field_type == DriveFileFieldType.SLIM:
|
||||
return SLIM_FILE_FIELDS
|
||||
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
|
||||
return FILE_FIELDS_WITH_PERMISSIONS
|
||||
else: # DriveFileFieldType.STANDARD
|
||||
return FILE_FIELDS
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
field_type: DriveFileFieldType,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
|
||||
query += " and trashed = false"
|
||||
query += generate_time_range_filter(start, end)
|
||||
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=query,
|
||||
**kwargs,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def crawl_folders_for_files(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
field_type: DriveFileFieldType,
|
||||
user_email: str,
|
||||
traversed_parent_ids: set[str],
|
||||
update_traversed_ids_func: Callable[[str], None],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[RetrievedDriveFile]:
|
||||
"""
|
||||
This function starts crawling from any folder. It is slower though.
|
||||
"""
|
||||
logging.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
|
||||
if parent_id not in traversed_parent_ids:
|
||||
logging.info("Parent id not in traversed parent ids, getting files")
|
||||
found_files = False
|
||||
file = {}
|
||||
try:
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
field_type=field_type,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
logging.info(f"Found file: {file['name']}, user email: {user_email}")
|
||||
found_files = True
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
)
|
||||
# Only mark a folder as done if it was fully traversed without errors
|
||||
# This usually indicates that the owner of the folder was impersonated.
|
||||
# In cases where this never happens, most likely the folder owner is
|
||||
# not part of the google workspace in question (or for oauth, the authenticated
|
||||
# user doesn't own the folder)
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
except Exception as e:
|
||||
if isinstance(e, HttpError) and e.status_code == 403:
|
||||
# don't yield an error here because this is expected behavior
|
||||
# when a user doesn't have access to a folder
|
||||
logging.debug(f"Error getting files in parent {parent_id}: {e}")
|
||||
else:
|
||||
logging.error(f"Error getting files in parent {parent_id}: {e}")
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
error=e,
|
||||
)
|
||||
else:
|
||||
logging.info(f"Skipping subfolder files since already traversed: {parent_id}")
|
||||
|
||||
for subfolder in _get_folders_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
logging.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
yield from crawl_folders_for_files(
|
||||
service=service,
|
||||
parent_id=subfolder["id"],
|
||||
field_type=field_type,
|
||||
user_email=user_email,
|
||||
traversed_parent_ids=traversed_parent_ids,
|
||||
update_traversed_ids_func=update_traversed_ids_func,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
|
||||
def get_files_in_shared_drive(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
field_type: DriveFileFieldType,
|
||||
max_num_pages: int,
|
||||
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
|
||||
cache_folders: bool = True,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
page_token: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
if page_token:
|
||||
logging.info(f"Using page token: {page_token}")
|
||||
kwargs[PAGE_TOKEN_KEY] = page_token
|
||||
|
||||
if cache_folders:
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
for folder in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields="nextPageToken, files(id)",
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(folder["id"])
|
||||
|
||||
# Get all files in the shared drive
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
|
||||
for file in execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function=service.files().list,
|
||||
max_num_pages=max_num_pages,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
):
|
||||
# If we found any files, mark this drive as traversed. When a user has access to a drive,
|
||||
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
|
||||
# empty drives.
|
||||
# NOTE: ^^ the above is not actually true due to folder restrictions:
|
||||
# https://support.google.com/a/users/answer/12380484?hl=en
|
||||
# So we may have to change this logic for people who use folder restrictions.
|
||||
update_traversed_ids_func(drive_id)
|
||||
yield file
|
||||
|
||||
|
||||
def get_all_files_in_my_drive_and_shared(
|
||||
service: GoogleDriveService,
|
||||
update_traversed_ids_func: Callable,
|
||||
field_type: DriveFileFieldType,
|
||||
include_shared_with_me: bool,
|
||||
max_num_pages: int,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
cache_folders: bool = True,
|
||||
page_token: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
if page_token:
|
||||
logging.info(f"Using page token: {page_token}")
|
||||
kwargs[PAGE_TOKEN_KEY] = page_token
|
||||
|
||||
if cache_folders:
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
if not include_shared_with_me:
|
||||
folder_query += " and 'me' in owners"
|
||||
found_folders = False
|
||||
for folder in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="user",
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(folder[GoogleFields.ID])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(get_root_folder_id(service))
|
||||
|
||||
# Then get the files
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
if not include_shared_with_me:
|
||||
file_query += " and 'me' in owners"
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
yield from execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function=service.files().list,
|
||||
max_num_pages=max_num_pages,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora="user",
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_for_oauth(
|
||||
service: GoogleDriveService,
|
||||
include_files_shared_with_me: bool,
|
||||
include_my_drives: bool,
|
||||
# One of the above 2 should be true
|
||||
include_shared_drives: bool,
|
||||
field_type: DriveFileFieldType,
|
||||
max_num_pages: int,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
page_token: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
if page_token:
|
||||
logging.info(f"Using page token: {page_token}")
|
||||
kwargs[PAGE_TOKEN_KEY] = page_token
|
||||
|
||||
should_get_all = include_shared_drives and include_my_drives and include_files_shared_with_me
|
||||
corpora = "allDrives" if should_get_all else "user"
|
||||
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
|
||||
if not should_get_all:
|
||||
if include_files_shared_with_me and not include_my_drives:
|
||||
file_query += " and not 'me' in owners"
|
||||
if not include_files_shared_with_me and include_my_drives:
|
||||
file_query += " and 'me' in owners"
|
||||
|
||||
yield from execute_paginated_retrieval_with_max_pages(
|
||||
max_num_pages=max_num_pages,
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora=corpora,
|
||||
includeItemsFromAllDrives=should_get_all,
|
||||
supportsAllDrives=should_get_all,
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Just in case we need to get the root folder id
|
||||
def get_root_folder_id(service: Resource) -> str:
|
||||
# we dont paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return service.files().get(fileId="root", fields=GoogleFields.ID.value).execute()[GoogleFields.ID.value]
|
||||
144
common/data_source/google_drive/model.py
Normal file
144
common/data_source/google_drive/model.py
Normal file
@ -0,0 +1,144 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
|
||||
from common.data_source.models import ConnectorCheckpoint, SecondsSinceUnixEpoch
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
SPREADSHEET_OPEN_FORMAT = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
POWERPOINT = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
# These correspond to The major stages of retrieval for google drive.
|
||||
# The stages for the oauth flow are:
|
||||
# get_all_files_for_oauth(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# crawl_folders_for_files()
|
||||
#
|
||||
# The stages for the service account flow are roughly:
|
||||
# get_all_user_emails(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# Then for each user:
|
||||
# get_files_in_my_drive()
|
||||
# get_files_in_shared_drive()
|
||||
# crawl_folders_for_files()
|
||||
class DriveRetrievalStage(str, Enum):
|
||||
START = "start"
|
||||
DONE = "done"
|
||||
# OAuth specific stages
|
||||
OAUTH_FILES = "oauth_files"
|
||||
|
||||
# Service account specific stages
|
||||
USER_EMAILS = "user_emails"
|
||||
MY_DRIVE_FILES = "my_drive_files"
|
||||
|
||||
# Used for both oauth and service account flows
|
||||
DRIVE_IDS = "drive_ids"
|
||||
SHARED_DRIVE_FILES = "shared_drive_files"
|
||||
FOLDER_FILES = "folder_files"
|
||||
|
||||
|
||||
class StageCompletion(BaseModel):
|
||||
"""
|
||||
Describes the point in the retrieval+indexing process that the
|
||||
connector is at. completed_until is the timestamp of the latest
|
||||
file that has been retrieved or error that has been yielded.
|
||||
Optional fields are used for retrieval stages that need more information
|
||||
for resuming than just the timestamp of the latest file.
|
||||
"""
|
||||
|
||||
stage: DriveRetrievalStage
|
||||
completed_until: SecondsSinceUnixEpoch
|
||||
current_folder_or_drive_id: str | None = None
|
||||
next_page_token: str | None = None
|
||||
|
||||
# only used for shared drives
|
||||
processed_drive_ids: set[str] = set()
|
||||
|
||||
def update(
|
||||
self,
|
||||
stage: DriveRetrievalStage,
|
||||
completed_until: SecondsSinceUnixEpoch,
|
||||
current_folder_or_drive_id: str | None = None,
|
||||
) -> None:
|
||||
self.stage = stage
|
||||
self.completed_until = completed_until
|
||||
self.current_folder_or_drive_id = current_folder_or_drive_id
|
||||
|
||||
|
||||
class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
# Checkpoint version of _retrieved_ids
|
||||
retrieved_folder_and_drive_ids: set[str]
|
||||
|
||||
# Describes the point in the retrieval+indexing process that the
|
||||
# checkpoint is at. when this is set to a given stage, the connector
|
||||
# has finished yielding all values from the previous stage.
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The latest timestamp of a file that has been retrieved per user email.
|
||||
# StageCompletion is used to track the completion of each stage, but the
|
||||
# timestamp part is not used for folder crawling.
|
||||
completion_map: ThreadSafeDict[str, StageCompletion]
|
||||
|
||||
# all file ids that have been retrieved
|
||||
all_retrieved_file_ids: set[str] = set()
|
||||
|
||||
# cached version of the drive and folder ids to retrieve
|
||||
drive_ids_to_retrieve: list[str] | None = None
|
||||
folder_ids_to_retrieve: list[str] | None = None
|
||||
|
||||
# cached user emails
|
||||
user_emails: list[str] | None = None
|
||||
|
||||
@field_serializer("completion_map")
|
||||
def serialize_completion_map(self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any) -> dict[str, StageCompletion]:
|
||||
return completion_map._dict
|
||||
|
||||
@field_validator("completion_map", mode="before")
|
||||
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
|
||||
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
|
||||
return ThreadSafeDict({k: StageCompletion.model_validate(val) for k, val in v.items()})
|
||||
|
||||
|
||||
class RetrievedDriveFile(BaseModel):
|
||||
"""
|
||||
Describes a file that has been retrieved from google drive.
|
||||
user_email is the email of the user that the file was retrieved
|
||||
by impersonating. If an error worthy of being reported is encountered,
|
||||
error should be set and later propagated as a ConnectorFailure.
|
||||
"""
|
||||
|
||||
# The stage at which this file was retrieved
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The file that was retrieved
|
||||
drive_file: GoogleDriveFileType
|
||||
|
||||
# The email of the user that the file was retrieved by impersonating
|
||||
user_email: str
|
||||
|
||||
# The id of the parent folder or drive of the file
|
||||
parent_id: str | None = None
|
||||
|
||||
# Any unexpected error that occurred while retrieving the file.
|
||||
# In particular, this is not used for 403/404 errors, which are expected
|
||||
# in the context of impersonating all the users to try to retrieve all
|
||||
# files from all their Drives and Folders.
|
||||
error: Exception | None = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
183
common/data_source/google_drive/section_extraction.py
Normal file
183
common/data_source/google_drive/section_extraction.py
Normal file
@ -0,0 +1,183 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.data_source.google_util.resource import GoogleDocsService
|
||||
from common.data_source.models import TextSection
|
||||
|
||||
HEADING_DELIMITER = "\n"
|
||||
|
||||
|
||||
class CurrentHeading(BaseModel):
|
||||
id: str | None
|
||||
text: str
|
||||
|
||||
|
||||
def get_document_sections(
|
||||
docs_service: GoogleDocsService,
|
||||
doc_id: str,
|
||||
) -> list[TextSection]:
|
||||
"""Extracts sections from a Google Doc, including their headings and content"""
|
||||
# Fetch the document structure
|
||||
http_request = docs_service.documents().get(documentId=doc_id)
|
||||
|
||||
# Google has poor support for tabs in the docs api, see
|
||||
# https://cloud.google.com/python/docs/reference/cloudtasks/
|
||||
# latest/google.cloud.tasks_v2.types.HttpRequest
|
||||
# https://developers.google.com/workspace/docs/api/how-tos/tabs
|
||||
# https://developers.google.com/workspace/docs/api/reference/rest/v1/documents/get
|
||||
# this is a hack to use the param mentioned in the rest api docs
|
||||
# TODO: check if it can be specified i.e. in documents()
|
||||
http_request.uri += "&includeTabsContent=true"
|
||||
doc = http_request.execute()
|
||||
|
||||
# Get the content
|
||||
tabs = doc.get("tabs", {})
|
||||
sections: list[TextSection] = []
|
||||
for tab in tabs:
|
||||
sections.extend(get_tab_sections(tab, doc_id))
|
||||
return sections
|
||||
|
||||
|
||||
def _is_heading(paragraph: dict[str, Any]) -> bool:
|
||||
"""Checks if a paragraph (a block of text in a drive document) is a heading"""
|
||||
if not ("paragraphStyle" in paragraph and "namedStyleType" in paragraph["paragraphStyle"]):
|
||||
return False
|
||||
|
||||
style = paragraph["paragraphStyle"]["namedStyleType"]
|
||||
is_heading = style.startswith("HEADING_")
|
||||
is_title = style.startswith("TITLE")
|
||||
return is_heading or is_title
|
||||
|
||||
|
||||
def _add_finished_section(
|
||||
sections: list[TextSection],
|
||||
doc_id: str,
|
||||
tab_id: str,
|
||||
current_heading: CurrentHeading,
|
||||
current_section: list[str],
|
||||
) -> None:
|
||||
"""Adds a finished section to the list of sections if the section has content.
|
||||
Returns the list of sections to use going forward, which may be the old list
|
||||
if a new section was not added.
|
||||
"""
|
||||
if not (current_section or current_heading.text):
|
||||
return
|
||||
# If we were building a previous section, add it to sections list
|
||||
|
||||
# this is unlikely to ever matter, but helps if the doc contains weird headings
|
||||
header_text = current_heading.text.replace(HEADING_DELIMITER, "")
|
||||
section_text = f"{header_text}{HEADING_DELIMITER}" + "\n".join(current_section)
|
||||
sections.append(
|
||||
TextSection(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, tab_id, current_heading.id),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _build_gdoc_section_link(doc_id: str, tab_id: str, heading_id: str | None) -> str:
|
||||
"""Builds a Google Doc link that jumps to a specific heading"""
|
||||
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
|
||||
# @Chris
|
||||
heading_str = f"#heading={heading_id}" if heading_id else ""
|
||||
return f"https://docs.google.com/document/d/{doc_id}/edit?tab={tab_id}{heading_str}"
|
||||
|
||||
|
||||
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the id from a heading paragraph element"""
|
||||
return paragraph["paragraphStyle"]["headingId"]
|
||||
|
||||
|
||||
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the text content from a paragraph element"""
|
||||
text_elements = []
|
||||
for element in paragraph.get("elements", []):
|
||||
if "textRun" in element:
|
||||
text_elements.append(element["textRun"].get("content", ""))
|
||||
|
||||
# Handle links
|
||||
if "textStyle" in element and "link" in element["textStyle"]:
|
||||
text_elements.append(f"({element['textStyle']['link'].get('url', '')})")
|
||||
|
||||
if "person" in element:
|
||||
name = element["person"].get("personProperties", {}).get("name", "")
|
||||
email = element["person"].get("personProperties", {}).get("email", "")
|
||||
person_str = "<Person|"
|
||||
if name:
|
||||
person_str += f"name: {name}, "
|
||||
if email:
|
||||
person_str += f"email: {email}"
|
||||
person_str += ">"
|
||||
text_elements.append(person_str)
|
||||
|
||||
if "richLink" in element:
|
||||
props = element["richLink"].get("richLinkProperties", {})
|
||||
title = props.get("title", "")
|
||||
uri = props.get("uri", "")
|
||||
link_str = f"[{title}]({uri})"
|
||||
text_elements.append(link_str)
|
||||
|
||||
return "".join(text_elements)
|
||||
|
||||
|
||||
def _extract_text_from_table(table: dict[str, Any]) -> str:
|
||||
"""
|
||||
Extracts the text content from a table element.
|
||||
"""
|
||||
row_strs = []
|
||||
|
||||
for row in table.get("tableRows", []):
|
||||
cells = row.get("tableCells", [])
|
||||
cell_strs = []
|
||||
for cell in cells:
|
||||
child_elements = cell.get("content", {})
|
||||
cell_str = []
|
||||
for child_elem in child_elements:
|
||||
if "paragraph" not in child_elem:
|
||||
continue
|
||||
cell_str.append(_extract_text_from_paragraph(child_elem["paragraph"]))
|
||||
cell_strs.append("".join(cell_str))
|
||||
row_strs.append(", ".join(cell_strs))
|
||||
return "\n".join(row_strs)
|
||||
|
||||
|
||||
def get_tab_sections(tab: dict[str, Any], doc_id: str) -> list[TextSection]:
|
||||
tab_id = tab["tabProperties"]["tabId"]
|
||||
content = tab.get("documentTab", {}).get("body", {}).get("content", [])
|
||||
|
||||
sections: list[TextSection] = []
|
||||
current_section: list[str] = []
|
||||
current_heading = CurrentHeading(id=None, text="")
|
||||
|
||||
for element in content:
|
||||
if "paragraph" in element:
|
||||
paragraph = element["paragraph"]
|
||||
|
||||
# If this is not a heading, add content to current section
|
||||
if not _is_heading(paragraph):
|
||||
text = _extract_text_from_paragraph(paragraph)
|
||||
if text.strip():
|
||||
current_section.append(text)
|
||||
continue
|
||||
|
||||
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
|
||||
|
||||
current_section = []
|
||||
|
||||
# Start new heading
|
||||
heading_id = _extract_id_from_heading(paragraph)
|
||||
heading_text = _extract_text_from_paragraph(paragraph)
|
||||
current_heading = CurrentHeading(
|
||||
id=heading_id,
|
||||
text=heading_text,
|
||||
)
|
||||
elif "table" in element:
|
||||
text = _extract_text_from_table(element["table"])
|
||||
if text.strip():
|
||||
current_section.append(text)
|
||||
|
||||
# Don't forget to add the last section
|
||||
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
|
||||
|
||||
return sections
|
||||
@ -1,77 +0,0 @@
|
||||
"""Google Drive connector"""
|
||||
|
||||
from typing import Any
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
get_google_creds,
|
||||
get_gmail_service
|
||||
)
|
||||
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Google Drive connector for accessing Google Drive files and folders"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.drive_service = None
|
||||
self.credentials = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Google Drive credentials"""
|
||||
try:
|
||||
creds, new_creds = get_google_creds(credentials, "drive")
|
||||
self.credentials = creds
|
||||
|
||||
if creds:
|
||||
self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))
|
||||
|
||||
return new_creds
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Google Drive: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Google Drive connector settings"""
|
||||
if not self.drive_service:
|
||||
raise ConnectorMissingCredentialError("Google Drive")
|
||||
|
||||
try:
|
||||
# Test connection by listing files
|
||||
self.drive_service.files().list(pageSize=1).execute()
|
||||
except HttpError as e:
|
||||
if e.resp.status in [401, 403]:
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
else:
|
||||
raise ConnectorValidationError(f"Google Drive validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Google Drive for recent file changes"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_state(self) -> Any:
|
||||
"""Load files from Google Drive state"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
0
common/data_source/google_util/__init__.py
Normal file
0
common/data_source/google_util/__init__.py
Normal file
157
common/data_source/google_util/auth.py
Normal file
157
common/data_source/google_util/auth.py
Normal file
@ -0,0 +1,157 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
|
||||
|
||||
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
|
||||
from common.data_source.google_util.constant import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
GOOGLE_SCOPES,
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
|
||||
|
||||
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
|
||||
"""we really don't want to be persisting the client id and secret anywhere but the
|
||||
environment.
|
||||
|
||||
Returns a string of serialized json.
|
||||
"""
|
||||
|
||||
# strip the client id and secret
|
||||
oauth_creds_json_str = oauth_creds.to_json()
|
||||
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
|
||||
oauth_creds_sanitized_json.pop("client_id", None)
|
||||
oauth_creds_sanitized_json.pop("client_secret", None)
|
||||
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
|
||||
return oauth_creds_sanitized_json_str
|
||||
|
||||
|
||||
def get_google_creds(
|
||||
credentials: dict[str, str],
|
||||
source: DocumentSource,
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going through
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
|
||||
Return a tuple where:
|
||||
The first element is the requested credentials
|
||||
The second element is a new credentials dict that the caller should write back
|
||||
to the db. This happens if token rotation occurs while loading credentials.
|
||||
"""
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
# OAUTH
|
||||
authentication_method: str = credentials.get(
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
GoogleOAuthAuthenticationMethod.UPLOADED,
|
||||
)
|
||||
|
||||
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
credentials_dict = json.loads(credentials_dict_str)
|
||||
|
||||
regenerated_from_client_secret = False
|
||||
if "client_id" not in credentials_dict or "client_secret" not in credentials_dict or "refresh_token" not in credentials_dict:
|
||||
try:
|
||||
credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
|
||||
except Exception as exc:
|
||||
raise PermissionError(
|
||||
"Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
|
||||
) from exc
|
||||
credentials_dict_str = json.dumps(credentials_dict)
|
||||
regenerated_from_client_secret = True
|
||||
|
||||
# only send what get_google_oauth_creds needs
|
||||
authorized_user_info = {}
|
||||
|
||||
# oauth_interactive is sanitized and needs credentials from the environment
|
||||
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
else:
|
||||
authorized_user_info["client_id"] = credentials_dict["client_id"]
|
||||
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
|
||||
|
||||
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
|
||||
|
||||
authorized_user_info["token"] = credentials_dict["token"]
|
||||
authorized_user_info["expiry"] = credentials_dict["expiry"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(token_json_str=token_json_str, source=source)
|
||||
|
||||
# tell caller to update token stored in DB if the refresh token changed
|
||||
if oauth_creds:
|
||||
should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
|
||||
if should_persist:
|
||||
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
|
||||
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
|
||||
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
else:
|
||||
oauth_creds_json_str = oauth_creds.to_json()
|
||||
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY],
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
|
||||
}
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
# SERVICE ACCOUNT
|
||||
service_account_key_json_str = credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY]
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
|
||||
service_creds = ServiceAccountCredentials.from_service_account_info(service_account_key, scopes=GOOGLE_SCOPES[source])
|
||||
|
||||
if not service_creds.valid or not service_creds.expired:
|
||||
service_creds.refresh(Request())
|
||||
|
||||
if not service_creds.valid:
|
||||
raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = oauth_creds or service_creds
|
||||
if creds is None:
|
||||
raise PermissionError(f"Unable to access {source} - unknown credential structure.")
|
||||
|
||||
return creds, new_creds_dict
|
||||
|
||||
|
||||
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
|
||||
"""creds_json only needs to contain client_id, client_secret and refresh_token to
|
||||
refresh the creds.
|
||||
|
||||
expiry and token are optional ... however, if passing in expiry, token
|
||||
should also be passed in or else we may not return any creds.
|
||||
(probably a sign we should refactor the function)
|
||||
"""
|
||||
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(
|
||||
info=creds_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logging.info("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception:
|
||||
logging.exception("Failed to refresh google drive access token")
|
||||
return None
|
||||
|
||||
return None
|
||||
103
common/data_source/google_util/constant.py
Normal file
103
common/data_source/google_util/constant.py
Normal file
@ -0,0 +1,103 @@
|
||||
from enum import Enum
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
|
||||
SLIM_BATCH_SIZE = 500
|
||||
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
|
||||
# this is counted under `/auth/drive.readonly`
|
||||
GOOGLE_SCOPES = {
|
||||
DocumentSource.GOOGLE_DRIVE: [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||
],
|
||||
DocumentSource.GMAIL: [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# This is the Oauth token
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
# This is the service account key
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
# The email saved for both auth types
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
|
||||
|
||||
# https://developers.google.com/workspace/guides/create-credentials
|
||||
# Internally defined authentication method type.
|
||||
# The value must be one of "oauth_interactive" or "uploaded"
|
||||
# Used to disambiguate whether credentials have already been created via
|
||||
# certain methods and what actions we allow users to take
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
|
||||
|
||||
class GoogleOAuthAuthenticationMethod(str, Enum):
|
||||
OAUTH_INTERACTIVE = "oauth_interactive"
|
||||
UPLOADED = "uploaded"
|
||||
|
||||
|
||||
USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||
|
||||
|
||||
# Error message substrings
|
||||
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||
SCOPE_INSTRUCTIONS = ""
|
||||
|
||||
|
||||
GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Google Drive Authorization</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
background: #f8fafc;
|
||||
color: #0f172a;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
}}
|
||||
.card {{
|
||||
background: white;
|
||||
padding: 32px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 8px 30px rgba(15, 23, 42, 0.1);
|
||||
max-width: 420px;
|
||||
text-align: center;
|
||||
}}
|
||||
h1 {{
|
||||
font-size: 1.5rem;
|
||||
margin-bottom: 12px;
|
||||
}}
|
||||
p {{
|
||||
font-size: 0.95rem;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h1>{heading}</h1>
|
||||
<p>{message}</p>
|
||||
<p>You can close this window.</p>
|
||||
</div>
|
||||
<script>
|
||||
(function(){{
|
||||
if (window.opener) {{
|
||||
window.opener.postMessage({payload_json}, "*");
|
||||
}}
|
||||
{auto_close}
|
||||
}})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
191
common/data_source/google_util/oauth_flow.py
Normal file
191
common/data_source/google_util/oauth_flow.py
Normal file
@ -0,0 +1,191 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Callable
|
||||
|
||||
import requests
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_SCOPES
|
||||
|
||||
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
|
||||
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
DEFAULT_DEVICE_INTERVAL = 5
|
||||
|
||||
|
||||
def _get_requested_scopes(source: DocumentSource) -> list[str]:
|
||||
"""Return the scopes to request, honoring an optional override env var."""
|
||||
override = os.environ.get("GOOGLE_OAUTH_SCOPE_OVERRIDE", "")
|
||||
if override.strip():
|
||||
scopes = [scope.strip() for scope in override.split(",") if scope.strip()]
|
||||
if scopes:
|
||||
return scopes
|
||||
return GOOGLE_SCOPES[source]
|
||||
|
||||
|
||||
def _get_oauth_timeout_secs() -> int:
|
||||
raw_timeout = os.environ.get("GOOGLE_OAUTH_FLOW_TIMEOUT_SECS", "300").strip()
|
||||
try:
|
||||
timeout = int(raw_timeout)
|
||||
except ValueError:
|
||||
timeout = 300
|
||||
return timeout
|
||||
|
||||
|
||||
def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_message: str) -> Any:
|
||||
if timeout_secs <= 0:
|
||||
return func()
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
error: dict[str, BaseException] = {}
|
||||
|
||||
def _target() -> None:
|
||||
try:
|
||||
result["value"] = func()
|
||||
except BaseException as exc: # pragma: no cover
|
||||
error["error"] = exc
|
||||
|
||||
thread = threading.Thread(target=_target, daemon=True)
|
||||
thread.start()
|
||||
thread.join(timeout_secs)
|
||||
if thread.is_alive():
|
||||
raise TimeoutError(timeout_message)
|
||||
if "error" in error:
|
||||
raise error["error"]
|
||||
return result.get("value")
|
||||
|
||||
|
||||
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
|
||||
if "client_id" in credentials:
|
||||
return credentials["client_id"], credentials.get("client_secret")
|
||||
for key in ("installed", "web"):
|
||||
if key in credentials and isinstance(credentials[key], dict):
|
||||
nested = credentials[key]
|
||||
if "client_id" not in nested:
|
||||
break
|
||||
return nested["client_id"], nested.get("client_secret")
|
||||
raise ValueError("Provided Google OAuth credentials are missing client_id.")
|
||||
|
||||
|
||||
def start_device_authorization_flow(
|
||||
credentials: dict[str, Any],
|
||||
source: DocumentSource,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
client_id, client_secret = _extract_client_info(credentials)
|
||||
data = {
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(_get_requested_scopes(source)),
|
||||
}
|
||||
if client_secret:
|
||||
data["client_secret"] = client_secret
|
||||
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
state = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"device_code": payload.get("device_code"),
|
||||
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
|
||||
}
|
||||
response_data = {
|
||||
"user_code": payload.get("user_code"),
|
||||
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
|
||||
"verification_url_complete": payload.get("verification_url_complete")
|
||||
or payload.get("verification_uri_complete"),
|
||||
"expires_in": payload.get("expires_in"),
|
||||
"interval": state["interval"],
|
||||
}
|
||||
return state, response_data
|
||||
|
||||
|
||||
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
|
||||
data = {
|
||||
"client_id": state["client_id"],
|
||||
"device_code": state["device_code"],
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}
|
||||
if state.get("client_secret"):
|
||||
data["client_secret"] = state["client_secret"]
|
||||
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
|
||||
scopes = _get_requested_scopes(source)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
client_config,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
open_browser = os.environ.get("GOOGLE_OAUTH_OPEN_BROWSER", "true").lower() != "false"
|
||||
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
|
||||
port = int(preferred_port) if preferred_port else 0
|
||||
timeout_secs = _get_oauth_timeout_secs()
|
||||
timeout_message = (
|
||||
f"Google OAuth verification timed out after {timeout_secs} seconds. "
|
||||
"Close any pending consent windows and rerun the connector configuration to try again."
|
||||
)
|
||||
|
||||
print("Launching Google OAuth flow. A browser window should open shortly.")
|
||||
print("If it does not, copy the URL shown in the console into your browser manually.")
|
||||
if timeout_secs > 0:
|
||||
print(f"You have {timeout_secs} seconds to finish granting access before the request times out.")
|
||||
|
||||
try:
|
||||
creds = _run_with_timeout(
|
||||
lambda: flow.run_local_server(port=port, open_browser=open_browser, prompt="consent"),
|
||||
timeout_secs,
|
||||
timeout_message,
|
||||
)
|
||||
except OSError as exc:
|
||||
allow_console = os.environ.get("GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK", "true").lower() != "false"
|
||||
if not allow_console:
|
||||
raise
|
||||
print(f"Local server flow failed ({exc}). Falling back to console-based auth.")
|
||||
creds = _run_with_timeout(flow.run_console, timeout_secs, timeout_message)
|
||||
except Warning as warning:
|
||||
warning_msg = str(warning)
|
||||
if "Scope has changed" in warning_msg:
|
||||
instructions = [
|
||||
"Google rejected one or more of the requested OAuth scopes.",
|
||||
"Fix options:",
|
||||
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
|
||||
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
|
||||
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
|
||||
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
|
||||
" (be aware the connector may lose functionality).",
|
||||
]
|
||||
raise RuntimeError("\n".join(instructions)) from warning
|
||||
raise
|
||||
|
||||
token_dict: dict[str, Any] = json.loads(creds.to_json())
|
||||
|
||||
print("\nGoogle OAuth flow completed successfully.")
|
||||
print("Copy the JSON blob below into GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR to reuse these tokens without re-authenticating:\n")
|
||||
print(json.dumps(token_dict, indent=2))
|
||||
print()
|
||||
|
||||
return token_dict
|
||||
|
||||
|
||||
def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||
"""Return a dict that contains OAuth tokens, running the flow if only a client config is provided."""
|
||||
if "refresh_token" in credentials and "token" in credentials:
|
||||
return credentials
|
||||
|
||||
client_config: dict[str, Any] | None = None
|
||||
if "installed" in credentials:
|
||||
client_config = {"installed": credentials["installed"]}
|
||||
elif "web" in credentials:
|
||||
client_config = {"web": credentials["web"]}
|
||||
|
||||
if client_config is None:
|
||||
raise ValueError(
|
||||
"Provided Google OAuth credentials are missing both tokens and a client configuration."
|
||||
)
|
||||
|
||||
return _run_local_server_flow(client_config, source)
|
||||
120
common/data_source/google_util/resource.py
Normal file
120
common/data_source/google_util/resource.py
Normal file
@ -0,0 +1,120 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
|
||||
from googleapiclient.discovery import (
|
||||
Resource, # type: ignore
|
||||
build, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class GoogleDriveService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class GoogleDocsService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class AdminService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class GmailService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class RefreshableDriveObject:
|
||||
"""
|
||||
Running Google drive service retrieval functions
|
||||
involves accessing methods of the service object (ie. files().list())
|
||||
which can raise a RefreshError if the access token is expired.
|
||||
This class is a wrapper that propagates the ability to refresh the access token
|
||||
and retry the final retrieval function until execute() is called.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
|
||||
):
|
||||
self.call_stack = call_stack
|
||||
self.creds = creds
|
||||
self.creds_getter = creds_getter
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name == "execute":
|
||||
return self.make_refreshable_execute()
|
||||
return RefreshableDriveObject(
|
||||
lambda creds: getattr(self.call_stack(creds), name),
|
||||
self.creds,
|
||||
self.creds_getter,
|
||||
)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return RefreshableDriveObject(
|
||||
lambda creds: self.call_stack(creds)(*args, **kwargs),
|
||||
self.creds,
|
||||
self.creds_getter,
|
||||
)
|
||||
|
||||
def make_refreshable_execute(self) -> Callable:
|
||||
def execute(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return self.call_stack(self.creds).execute(*args, **kwargs)
|
||||
except RefreshError as e:
|
||||
logging.warning(f"RefreshError, going to attempt a creds refresh and retry: {e}")
|
||||
# Refresh the access token
|
||||
self.creds = self.creds_getter()
|
||||
return self.call_stack(self.creds).execute(*args, **kwargs)
|
||||
|
||||
return execute
|
||||
|
||||
|
||||
def _get_google_service(
|
||||
service_name: str,
|
||||
service_version: str,
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
|
||||
service: Resource
|
||||
if isinstance(creds, ServiceAccountCredentials):
|
||||
# NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes
|
||||
creds = creds.with_subject(user_email)
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
elif isinstance(creds, OAuthCredentials):
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
|
||||
return service
|
||||
|
||||
|
||||
def get_google_docs_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDocsService:
|
||||
return _get_google_service("docs", "v1", creds, user_email)
|
||||
|
||||
|
||||
def get_drive_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService:
|
||||
return _get_google_service("drive", "v3", creds, user_email)
|
||||
|
||||
|
||||
def get_admin_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> AdminService:
|
||||
return _get_google_service("admin", "directory_v1", creds, user_email)
|
||||
|
||||
|
||||
def get_gmail_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GmailService:
|
||||
return _get_google_service("gmail", "v1", creds, user_email)
|
||||
152
common/data_source/google_util/util.py
Normal file
152
common/data_source/google_util/util.py
Normal file
@ -0,0 +1,152 @@
|
||||
import logging
|
||||
import socket
|
||||
from collections.abc import Callable, Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore # type: ignore
|
||||
|
||||
from common.data_source.google_drive.model import GoogleDriveFileType
|
||||
|
||||
|
||||
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||
class GoogleFields(str, Enum):
|
||||
ID = "id"
|
||||
CREATED_TIME = "createdTime"
|
||||
MODIFIED_TIME = "modifiedTime"
|
||||
NAME = "name"
|
||||
SIZE = "size"
|
||||
PARENTS = "parents"
|
||||
|
||||
|
||||
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
|
||||
PAGE_TOKEN_KEY = "pageToken"
|
||||
ORDER_BY_KEY = "orderBy"
|
||||
|
||||
|
||||
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
|
||||
"""
|
||||
Get the owners of a file if the attribute is present.
|
||||
"""
|
||||
return [email for owner in file.get("owners", []) if (email := owner.get("emailAddress")) and email.split("@")[-1] == primary_admin_email.split("@")[-1]]
|
||||
|
||||
|
||||
# included for type purposes; caller should not need to address
|
||||
# Nones unless max_num_pages is specified. Use
|
||||
# execute_paginated_retrieval_with_max_pages instead if you want
|
||||
# the early stop + yield None after max_num_pages behavior.
|
||||
def execute_paginated_retrieval(
|
||||
retrieval_function: Callable,
|
||||
list_key: str | None = None,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
for item in _execute_paginated_retrieval(
|
||||
retrieval_function,
|
||||
list_key,
|
||||
continue_on_404_or_403,
|
||||
**kwargs,
|
||||
):
|
||||
if not isinstance(item, str):
|
||||
yield item
|
||||
|
||||
|
||||
def execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function: Callable,
|
||||
max_num_pages: int,
|
||||
list_key: str | None = None,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
yield from _execute_paginated_retrieval(
|
||||
retrieval_function,
|
||||
list_key,
|
||||
continue_on_404_or_403,
|
||||
max_num_pages=max_num_pages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _execute_paginated_retrieval(
|
||||
retrieval_function: Callable,
|
||||
list_key: str | None = None,
|
||||
continue_on_404_or_403: bool = False,
|
||||
max_num_pages: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
"""Execute a paginated retrieval from Google Drive API
|
||||
Args:
|
||||
retrieval_function: The specific list function to call (e.g., service.files().list)
|
||||
list_key: If specified, each object returned by the retrieval function
|
||||
will be accessed at the specified key and yielded from.
|
||||
continue_on_404_or_403: If True, the retrieval will continue even if the request returns a 404 or 403 error.
|
||||
max_num_pages: If specified, the retrieval will stop after the specified number of pages and yield None.
|
||||
**kwargs: Arguments to pass to the list function
|
||||
"""
|
||||
if "fields" not in kwargs or "nextPageToken" not in kwargs["fields"]:
|
||||
raise ValueError("fields must contain nextPageToken for execute_paginated_retrieval")
|
||||
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
|
||||
num_pages = 0
|
||||
while next_page_token is not None:
|
||||
if max_num_pages is not None and num_pages >= max_num_pages:
|
||||
yield next_page_token
|
||||
return
|
||||
num_pages += 1
|
||||
request_kwargs = kwargs.copy()
|
||||
if next_page_token:
|
||||
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
|
||||
results = _execute_single_retrieval(
|
||||
retrieval_function,
|
||||
continue_on_404_or_403,
|
||||
**request_kwargs,
|
||||
)
|
||||
|
||||
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
|
||||
if list_key:
|
||||
for item in results.get(list_key, []):
|
||||
yield item
|
||||
else:
|
||||
yield results
|
||||
|
||||
|
||||
def _execute_single_retrieval(
|
||||
retrieval_function: Callable,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**request_kwargs: Any,
|
||||
) -> GoogleDriveFileType:
|
||||
"""Execute a single retrieval from Google Drive API"""
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
except HttpError as e:
|
||||
if e.resp.status >= 500:
|
||||
results = retrieval_function()
|
||||
elif e.resp.status == 400:
|
||||
if "pageToken" in request_kwargs and "Invalid Value" in str(e) and "pageToken" in str(e):
|
||||
logging.warning(f"Invalid page token: {request_kwargs['pageToken']}, retrying from start of request")
|
||||
request_kwargs.pop("pageToken")
|
||||
return _execute_single_retrieval(
|
||||
retrieval_function,
|
||||
continue_on_404_or_403,
|
||||
**request_kwargs,
|
||||
)
|
||||
logging.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logging.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = retrieval_function()
|
||||
else:
|
||||
logging.exception("Error executing request:")
|
||||
raise e
|
||||
except (TimeoutError, socket.timeout) as error:
|
||||
logging.warning(
|
||||
"Timed out executing Google API request; retrying with backoff. Details: %s",
|
||||
error,
|
||||
)
|
||||
results = retrieval_function()
|
||||
|
||||
return results
|
||||
141
common/data_source/google_util/util_threadpool_concurrency.py
Normal file
141
common/data_source/google_util/util_threadpool_concurrency.py
Normal file
@ -0,0 +1,141 @@
|
||||
import collections.abc
|
||||
import copy
|
||||
import threading
|
||||
from collections.abc import Callable, Iterator, MutableMapping
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
R = TypeVar("R")
|
||||
KT = TypeVar("KT") # Key type
|
||||
VT = TypeVar("VT") # Value type
|
||||
_T = TypeVar("_T") # Default type
|
||||
|
||||
|
||||
class ThreadSafeDict(MutableMapping[KT, VT]):
|
||||
"""
|
||||
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
|
||||
Implements the MutableMapping interface to provide a complete dictionary-like interface.
|
||||
|
||||
Example usage:
|
||||
# Create a thread-safe dictionary
|
||||
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
|
||||
|
||||
# Basic operations (atomic)
|
||||
safe_dict["key"] = 1
|
||||
value = safe_dict["key"]
|
||||
del safe_dict["key"]
|
||||
|
||||
# Bulk operations (atomic)
|
||||
safe_dict.update({"key1": 1, "key2": 2})
|
||||
"""
|
||||
|
||||
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
|
||||
self._dict: dict[KT, VT] = input_dict or {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __getitem__(self, key: KT) -> VT:
|
||||
with self.lock:
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: KT, value: VT) -> None:
|
||||
with self.lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
with self.lock:
|
||||
del self._dict[key]
|
||||
|
||||
def __iter__(self) -> Iterator[KT]:
|
||||
# Return a snapshot of keys to avoid potential modification during iteration
|
||||
with self.lock:
|
||||
return iter(list(self._dict.keys()))
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self.lock:
|
||||
return len(self._dict)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(cls.validate, handler(dict[KT, VT]))
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
if isinstance(v, dict):
|
||||
return ThreadSafeDict(v)
|
||||
return v
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
return ThreadSafeDict(copy.deepcopy(self._dict))
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all items from the dictionary atomically."""
|
||||
with self.lock:
|
||||
self._dict.clear()
|
||||
|
||||
def copy(self) -> dict[KT, VT]:
|
||||
"""Return a shallow copy of the dictionary atomically."""
|
||||
with self.lock:
|
||||
return self._dict.copy()
|
||||
|
||||
@overload
|
||||
def get(self, key: KT) -> VT | None: ...
|
||||
|
||||
@overload
|
||||
def get(self, key: KT, default: VT | _T) -> VT | _T: ...
|
||||
|
||||
def get(self, key: KT, default: Any = None) -> Any:
|
||||
"""Get a value with a default, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def pop(self, key: KT, default: Any = None) -> Any:
|
||||
"""Remove and return a value with optional default, atomically."""
|
||||
with self.lock:
|
||||
if default is None:
|
||||
return self._dict.pop(key)
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def setdefault(self, key: KT, default: VT) -> VT:
|
||||
"""Set a default value if key is missing, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.setdefault(key, default)
|
||||
|
||||
def update(self, *args: Any, **kwargs: VT) -> None:
|
||||
"""Update the dictionary atomically from another mapping or from kwargs."""
|
||||
with self.lock:
|
||||
self._dict.update(*args, **kwargs)
|
||||
|
||||
def items(self) -> collections.abc.ItemsView[KT, VT]:
|
||||
"""Return a view of (key, value) pairs atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ItemsView(self)
|
||||
|
||||
def keys(self) -> collections.abc.KeysView[KT]:
|
||||
"""Return a view of keys atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.KeysView(self)
|
||||
|
||||
def values(self) -> collections.abc.ValuesView[VT]:
|
||||
"""Return a view of values atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ValuesView(self)
|
||||
|
||||
@overload
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[VT], VT], default: VT) -> tuple[VT, VT]: ...
|
||||
|
||||
@overload
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T) -> tuple[VT | _T, VT]: ...
|
||||
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[Any], VT], default: Any = None) -> tuple[Any, VT]:
|
||||
"""Replace a value from the dict with a function applied to the previous value, atomically.
|
||||
|
||||
Returns:
|
||||
A tuple of the previous value and the new value.
|
||||
"""
|
||||
with self.lock:
|
||||
val = self._dict.get(key, default)
|
||||
new_val = value_callback(val)
|
||||
self._dict[key] = new_val
|
||||
return val, new_val
|
||||
@ -305,4 +305,4 @@ class ProcessedSlackMessage:
|
||||
SecondsSinceUnixEpoch = float
|
||||
GenerateDocumentsOutput = Any
|
||||
GenerateSlimDocumentOutput = Any
|
||||
CheckpointOutput = Any
|
||||
CheckpointOutput = Any
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from retry import retry
|
||||
|
||||
@ -33,7 +32,7 @@ from common.data_source.utils import (
|
||||
batch_generator,
|
||||
fetch_notion_data,
|
||||
properties_to_str,
|
||||
filter_pages_by_time
|
||||
filter_pages_by_time, datetime_from_string
|
||||
)
|
||||
|
||||
|
||||
@ -253,6 +252,8 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
all_child_page_ids: list[str] = []
|
||||
|
||||
for page in pages:
|
||||
if isinstance(page, dict):
|
||||
page = NotionPage(**page)
|
||||
if page.id in self.indexed_pages:
|
||||
logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
|
||||
continue
|
||||
@ -291,9 +292,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
blob=blob,
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=page_title,
|
||||
extension="txt",
|
||||
extension=".txt",
|
||||
size_bytes=len(blob),
|
||||
doc_updated_at=datetime.fromisoformat(page.last_edited_time).astimezone(timezone.utc)
|
||||
doc_updated_at=datetime_from_string(page.last_edited_time)
|
||||
)
|
||||
|
||||
if self.recursive_index_enabled and all_child_page_ids:
|
||||
|
||||
@ -9,15 +9,16 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
|
||||
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, as_completed, wait
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import lru_cache, wraps
|
||||
from io import BytesIO
|
||||
from itertools import islice
|
||||
from numbers import Integral
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, IO, TypeVar, cast, Iterable, Generic
|
||||
from urllib.parse import quote, urlparse, urljoin, parse_qs
|
||||
from typing import IO, Any, Generic, Iterable, Optional, Protocol, TypeVar, cast
|
||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
||||
|
||||
import boto3
|
||||
import chardet
|
||||
@ -25,8 +26,6 @@ import requests
|
||||
from botocore.client import Config
|
||||
from botocore.credentials import RefreshableCredentials
|
||||
from botocore.session import get_session
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
from mypy_boto3_s3 import S3Client
|
||||
from retry import retry
|
||||
@ -35,15 +34,18 @@ from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from common.data_source.config import (
|
||||
BlobType,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
_ITERATION_LIMIT,
|
||||
_NOTION_CALL_TIMEOUT,
|
||||
_SLACK_LIMIT,
|
||||
CONFLUENCE_OAUTH_TOKEN_URL,
|
||||
DOWNLOAD_CHUNK_SIZE,
|
||||
SIZE_THRESHOLD_BUFFER, _NOTION_CALL_TIMEOUT, _ITERATION_LIMIT, CONFLUENCE_OAUTH_TOKEN_URL,
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE, _SLACK_LIMIT, EXCLUDED_IMAGE_TYPES
|
||||
EXCLUDED_IMAGE_TYPES,
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE,
|
||||
SIZE_THRESHOLD_BUFFER,
|
||||
BlobType,
|
||||
)
|
||||
from common.data_source.exceptions import RateLimitTriedTooManyTimesError
|
||||
from common.data_source.interfaces import SecondsSinceUnixEpoch, CT, LoadFunction, \
|
||||
CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, TokenResponse, OnyxExtensionType
|
||||
from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
|
||||
from common.data_source.models import BasicExpertInfo, Document
|
||||
|
||||
|
||||
@ -80,11 +82,7 @@ def is_valid_image_type(mime_type: str) -> bool:
|
||||
Returns:
|
||||
True if the MIME type is a valid image type, False otherwise
|
||||
"""
|
||||
return (
|
||||
bool(mime_type)
|
||||
and mime_type.startswith("image/")
|
||||
and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
)
|
||||
return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
|
||||
|
||||
"""If you want to allow the external service to tell you when you've hit the rate limit,
|
||||
@ -109,18 +107,12 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
|
||||
FORBIDDEN_RETRY_DELAY = 10
|
||||
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
|
||||
logging.warning(
|
||||
"403 error. This sometimes happens when we hit "
|
||||
f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
|
||||
)
|
||||
logging.warning(f"403 error. This sometimes happens when we hit Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds...")
|
||||
return FORBIDDEN_RETRY_DELAY
|
||||
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
if e.response.status_code != 429 and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower():
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
@ -130,9 +122,7 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logging.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
logging.warning(f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds...")
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
@ -140,14 +130,10 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logging.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
logging.warning(f"Rate limiting with retry header. Retrying after {retry_after} seconds...")
|
||||
delay = retry_after
|
||||
else:
|
||||
logging.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
logging.warning("Rate limiting without retry header. Retrying with exponential backoff...")
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
@ -162,16 +148,10 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
|
||||
parsed_url = urlparse(path)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params[param] = [value]
|
||||
return (
|
||||
path.split("?")[0]
|
||||
+ "?"
|
||||
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
)
|
||||
return path.split("?")[0] + "?" + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
def build_confluence_document_id(base_url: str, content_url: str, is_cloud: bool) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
@ -204,17 +184,13 @@ def get_start_param_from_url(url: str) -> int:
|
||||
return int(start_str) if start_str else 0
|
||||
|
||||
|
||||
def wrap_request_to_handle_ratelimiting(
|
||||
request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
|
||||
) -> R:
|
||||
def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30) -> R:
|
||||
def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
|
||||
for _ in range(max_waits):
|
||||
response = request_fn(*args, **kwargs)
|
||||
if response.status_code == 429:
|
||||
try:
|
||||
wait_time = int(
|
||||
response.headers.get("Retry-After", default_wait_time_sec)
|
||||
)
|
||||
wait_time = int(response.headers.get("Retry-After", default_wait_time_sec))
|
||||
except ValueError:
|
||||
wait_time = default_wait_time_sec
|
||||
|
||||
@ -241,6 +217,7 @@ rl_requests = _RateLimitedRequest
|
||||
|
||||
# Blob Storage Utilities
|
||||
|
||||
|
||||
def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], european_residency: bool = False) -> S3Client:
|
||||
"""Create S3 client for different blob storage types"""
|
||||
if bucket_type == BlobType.R2:
|
||||
@ -325,9 +302,7 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
|
||||
"""Detect bucket region"""
|
||||
try:
|
||||
response = s3_client.head_bucket(Bucket=bucket_name)
|
||||
bucket_region = response.get("BucketRegion") or response.get(
|
||||
"ResponseMetadata", {}
|
||||
).get("HTTPHeaders", {}).get("x-amz-bucket-region")
|
||||
bucket_region = response.get("BucketRegion") or response.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region")
|
||||
|
||||
if bucket_region:
|
||||
logging.debug(f"Detected bucket region: {bucket_region}")
|
||||
@ -367,9 +342,7 @@ def read_stream_with_limit(body: Any, key: str, size_threshold: int) -> bytes |
|
||||
bytes_read += len(chunk)
|
||||
|
||||
if bytes_read > size_threshold + SIZE_THRESHOLD_BUFFER:
|
||||
logging.warning(
|
||||
f"{key} exceeds size threshold of {size_threshold}. Skipping."
|
||||
)
|
||||
logging.warning(f"{key} exceeds size threshold of {size_threshold}. Skipping.")
|
||||
return None
|
||||
|
||||
return b"".join(chunks)
|
||||
@ -417,11 +390,7 @@ def read_text_file(
|
||||
try:
|
||||
line = line.decode(encoding) if isinstance(line, bytes) else line
|
||||
except UnicodeDecodeError:
|
||||
line = (
|
||||
line.decode(encoding, errors=errors)
|
||||
if isinstance(line, bytes)
|
||||
else line
|
||||
)
|
||||
line = line.decode(encoding, errors=errors) if isinstance(line, bytes) else line
|
||||
|
||||
# optionally parse metadata in the first line
|
||||
if ind == 0 and not ignore_onyx_metadata:
|
||||
@ -550,9 +519,9 @@ def to_bytesio(stream: IO[bytes]) -> BytesIO:
|
||||
return BytesIO(data)
|
||||
|
||||
|
||||
|
||||
# Slack Utilities
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_base_url(token: str) -> str:
|
||||
"""Get and cache Slack workspace base URL"""
|
||||
@ -567,9 +536,7 @@ def get_message_link(event: dict, client: WebClient, channel_id: str) -> str:
|
||||
thread_ts = event.get("thread_ts")
|
||||
base_url = get_base_url(client.token)
|
||||
|
||||
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
|
||||
f"?thread_ts={thread_ts}" if thread_ts else ""
|
||||
)
|
||||
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (f"?thread_ts={thread_ts}" if thread_ts else "")
|
||||
return link
|
||||
|
||||
|
||||
@ -578,9 +545,7 @@ def make_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Sl
|
||||
return call(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
def make_paginated_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Generator[dict[str, Any], None, None]:
|
||||
"""Make paginated Slack API call"""
|
||||
return _make_slack_api_call_paginated(call)(**kwargs)
|
||||
|
||||
@ -652,14 +617,9 @@ class SlackTextCleaner:
|
||||
if user_id not in self._id_to_name_map:
|
||||
try:
|
||||
response = self._client.users_info(user=user_id)
|
||||
self._id_to_name_map[user_id] = (
|
||||
response["user"]["profile"]["display_name"]
|
||||
or response["user"]["profile"]["real_name"]
|
||||
)
|
||||
self._id_to_name_map[user_id] = response["user"]["profile"]["display_name"] or response["user"]["profile"]["real_name"]
|
||||
except SlackApiError as e:
|
||||
logging.exception(
|
||||
f"Error fetching data for user {user_id}: {e.response['error']}"
|
||||
)
|
||||
logging.exception(f"Error fetching data for user {user_id}: {e.response['error']}")
|
||||
raise
|
||||
|
||||
return self._id_to_name_map[user_id]
|
||||
@ -677,9 +637,7 @@ class SlackTextCleaner:
|
||||
|
||||
message = message.replace(f"<@{user_id}>", f"@{user_name}")
|
||||
except Exception:
|
||||
logging.exception(
|
||||
f"Unable to replace user ID with username for user_id '{user_id}'"
|
||||
)
|
||||
logging.exception(f"Unable to replace user ID with username for user_id '{user_id}'")
|
||||
|
||||
return message
|
||||
|
||||
@ -705,9 +663,7 @@ class SlackTextCleaner:
|
||||
"""Basic channel replacement"""
|
||||
channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
|
||||
for channel_id, channel_name in channel_matches:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
|
||||
)
|
||||
message = message.replace(f"<#{channel_id}|{channel_name}>", f"#{channel_name}")
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
@ -732,16 +688,14 @@ class SlackTextCleaner:
|
||||
|
||||
# Gmail Utilities
|
||||
|
||||
|
||||
def is_mail_service_disabled_error(error: HttpError) -> bool:
|
||||
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
|
||||
if error.resp.status != 400:
|
||||
return False
|
||||
|
||||
error_message = str(error)
|
||||
return (
|
||||
"Mail service not enabled" in error_message
|
||||
or "failedPrecondition" in error_message
|
||||
)
|
||||
return "Mail service not enabled" in error_message or "failedPrecondition" in error_message
|
||||
|
||||
|
||||
def build_time_range_query(
|
||||
@ -789,59 +743,11 @@ def get_message_body(payload: dict[str, Any]) -> str:
|
||||
return message_body
|
||||
|
||||
|
||||
def get_google_creds(
|
||||
credentials: dict[str, Any],
|
||||
source: str
|
||||
) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
|
||||
"""Get Google credentials based on authentication type."""
|
||||
# Simplified credential loading - in production this would handle OAuth and service accounts
|
||||
primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
|
||||
|
||||
if not primary_admin_email:
|
||||
raise ValueError("Primary admin email is required")
|
||||
|
||||
# Return None for credentials and empty dict for new creds
|
||||
# In a real implementation, this would handle actual credential loading
|
||||
return None, {}
|
||||
|
||||
|
||||
def get_admin_service(creds: OAuthCredentials | ServiceAccountCredentials, admin_email: str):
|
||||
"""Get Google Admin service instance."""
|
||||
# Simplified implementation
|
||||
return None
|
||||
|
||||
|
||||
def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_email: str):
|
||||
"""Get Gmail service instance."""
|
||||
# Simplified implementation
|
||||
return None
|
||||
|
||||
|
||||
def execute_paginated_retrieval(
|
||||
retrieval_function,
|
||||
list_key: str,
|
||||
fields: str,
|
||||
**kwargs
|
||||
):
|
||||
"""Execute paginated retrieval from Google APIs."""
|
||||
# Simplified pagination implementation
|
||||
return []
|
||||
|
||||
|
||||
def execute_single_retrieval(
|
||||
retrieval_function,
|
||||
list_key: Optional[str],
|
||||
**kwargs
|
||||
):
|
||||
"""Execute single retrieval from Google APIs."""
|
||||
# Simplified single retrieval implementation
|
||||
return []
|
||||
|
||||
|
||||
def time_str_to_utc(time_str: str):
|
||||
"""Convert time string to UTC datetime."""
|
||||
from datetime import datetime
|
||||
return datetime.fromisoformat(time_str.replace('Z', '+00:00'))
|
||||
|
||||
return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
|
||||
|
||||
# Notion Utilities
|
||||
@ -865,12 +771,7 @@ def batch_generator(
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def fetch_notion_data(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
method: str = "GET",
|
||||
json_data: Optional[dict] = None
|
||||
) -> dict[str, Any]:
|
||||
def fetch_notion_data(url: str, headers: dict[str, str], method: str = "GET", json_data: Optional[dict] = None) -> dict[str, Any]:
|
||||
"""Fetch data from Notion API with retry logic."""
|
||||
try:
|
||||
if method == "GET":
|
||||
@ -899,10 +800,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
list_properties.append(_recurse_list_properties(item))
|
||||
else:
|
||||
list_properties.append(str(item))
|
||||
return (
|
||||
", ".join([list_property for list_property in list_properties if list_property])
|
||||
or None
|
||||
)
|
||||
return ", ".join([list_property for list_property in list_properties if list_property]) or None
|
||||
|
||||
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
|
||||
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
|
||||
@ -955,12 +853,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def filter_pages_by_time(
|
||||
pages: list[dict[str, Any]],
|
||||
start: float,
|
||||
end: float,
|
||||
filter_field: str = "last_edited_time"
|
||||
) -> list[dict[str, Any]]:
|
||||
def filter_pages_by_time(pages: list[dict[str, Any]], start: float, end: float, filter_field: str = "last_edited_time") -> list[dict[str, Any]]:
|
||||
"""Filter pages by time range."""
|
||||
from datetime import datetime
|
||||
|
||||
@ -1005,9 +898,7 @@ def load_all_docs_from_checkpoint_connector(
|
||||
) -> list[Document]:
|
||||
return _load_all_docs(
|
||||
connector=connector,
|
||||
load=lambda checkpoint: connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
),
|
||||
load=lambda checkpoint: connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint),
|
||||
)
|
||||
|
||||
|
||||
@ -1042,9 +933,7 @@ def process_confluence_user_profiles_override(
|
||||
]
|
||||
|
||||
|
||||
def confluence_refresh_tokens(
|
||||
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
|
||||
) -> dict[str, Any]:
|
||||
def confluence_refresh_tokens(client_id: str, client_secret: str, cloud_id: str, refresh_token: str) -> dict[str, Any]:
|
||||
# rotate the refresh and access token
|
||||
# Note that access tokens are only good for an hour in confluence cloud,
|
||||
# so we're going to have problems if the connector runs for longer
|
||||
@ -1080,9 +969,7 @@ def confluence_refresh_tokens(
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread, Generic[R]):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
):
|
||||
def __init__(self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any):
|
||||
super().__init__()
|
||||
self.timeout = timeout
|
||||
self.func = func
|
||||
@ -1097,14 +984,10 @@ class TimeoutThread(threading.Thread, Generic[R]):
|
||||
self.exception = e
|
||||
|
||||
def end(self) -> None:
|
||||
raise TimeoutError(
|
||||
f"Function {self.func.__name__} timed out after {self.timeout} seconds"
|
||||
)
|
||||
raise TimeoutError(f"Function {self.func.__name__} timed out after {self.timeout} seconds")
|
||||
|
||||
|
||||
def run_with_timeout(
|
||||
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
) -> R:
|
||||
def run_with_timeout(timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any) -> R:
|
||||
"""
|
||||
Executes a function with a timeout. If the function doesn't complete within the specified
|
||||
timeout, raises TimeoutError.
|
||||
@ -1136,7 +1019,81 @@ def validate_attachment_filetype(
|
||||
title = attachment.get("title", "")
|
||||
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
|
||||
|
||||
return is_accepted_file_ext(
|
||||
"." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
|
||||
)
|
||||
return is_accepted_file_ext("." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document)
|
||||
|
||||
|
||||
class CallableProtocol(Protocol):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
def run_functions_tuples_in_parallel(
|
||||
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
|
||||
allow_failures: bool = False,
|
||||
max_workers: int | None = None,
|
||||
) -> list[Any]:
|
||||
"""
|
||||
Executes multiple functions in parallel and returns a list of the results for each function.
|
||||
This function preserves contextvars across threads, which is important for maintaining
|
||||
context like tenant IDs in database sessions.
|
||||
|
||||
Args:
|
||||
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
|
||||
allow_failures: if set to True, then the function result will just be None
|
||||
max_workers: Max number of worker threads
|
||||
|
||||
Returns:
|
||||
list: A list of results from each function, in the same order as the input functions.
|
||||
"""
|
||||
workers = min(max_workers, len(functions_with_args)) if max_workers is not None else len(functions_with_args)
|
||||
|
||||
if workers <= 0:
|
||||
return []
|
||||
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# The primary reason for propagating contextvars is to allow acquiring a db session
|
||||
# that respects tenant id. Context.run is expected to be low-overhead, but if we later
|
||||
# find that it is increasing latency we can make using it optional.
|
||||
future_to_index = {executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args)}
|
||||
|
||||
for future in as_completed(future_to_index):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
results.append((index, future.result()))
|
||||
except Exception as e:
|
||||
logging.exception(f"Function at index {index} failed due to {e}")
|
||||
results.append((index, None)) # type: ignore
|
||||
|
||||
if not allow_failures:
|
||||
raise
|
||||
|
||||
results.sort(key=lambda x: x[0])
|
||||
return [result for index, result in results]
|
||||
|
||||
|
||||
def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
|
||||
return ind, next(gen, None)
|
||||
|
||||
|
||||
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
|
||||
"""
|
||||
Runs the list of generators with thread-level parallelism, yielding
|
||||
results as available. The asynchronous nature of this yielding means
|
||||
that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
|
||||
FURTHER ITEMS WERE PRODUCED by the input gens. Only use this function
|
||||
if you are consuming all elements from the generators OR it is acceptable
|
||||
for some extra generator code to run and not have the result(s) yielded.
|
||||
"""
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_index: dict[Future[tuple[int, R | None]], int] = {executor.submit(_next_or_none, ind, gen): ind for ind, gen in enumerate(gens)}
|
||||
|
||||
next_ind = len(gens)
|
||||
while future_to_index:
|
||||
done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
|
||||
for future in done:
|
||||
ind, result = future.result()
|
||||
if result is not None:
|
||||
yield result
|
||||
future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
|
||||
next_ind += 1
|
||||
del future_to_index[future]
|
||||
|
||||
@ -39,5 +39,6 @@
|
||||
"n_hop_with_weight": {"type": "varchar", "default": ""},
|
||||
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"doc_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"toc_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}
|
||||
"toc_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"raptor_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
"rank": "999",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gpt-5",
|
||||
@ -174,6 +175,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "930",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "grok-4",
|
||||
@ -330,6 +332,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,TTS,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
"rank": "950",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "Moonshot-Kimi-K2-Instruct",
|
||||
@ -714,6 +717,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
"rank": "940",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "glm-4.5",
|
||||
@ -859,6 +863,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
"rank": "830",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
@ -880,7 +885,8 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
"llm": []
|
||||
"llm": [],
|
||||
"rank": "890"
|
||||
},
|
||||
{
|
||||
"name": "VLLM",
|
||||
@ -892,8 +898,9 @@
|
||||
{
|
||||
"name": "Moonshot",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
|
||||
"status": "1",
|
||||
"rank": "960",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "kimi-thinking-preview",
|
||||
@ -916,6 +923,20 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "kimi-k2-thinking",
|
||||
"tags": "LLM,CHAT,256k",
|
||||
"max_tokens": 262144,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "kimi-k2-thinking-turbo",
|
||||
"tags": "LLM,CHAT,256k",
|
||||
"max_tokens": 262144,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "kimi-k2-turbo-preview",
|
||||
"tags": "LLM,CHAT,256k",
|
||||
@ -932,25 +953,46 @@
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshot-v1-8k",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 7900,
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshot-v1-32k",
|
||||
"tags": "LLM,CHAT,",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshot-v1-128k",
|
||||
"tags": "LLM,CHAT",
|
||||
"max_tokens": 128000,
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshot-v1-8k-vision-preview",
|
||||
"tags": "LLM,IMAGE2TEXT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshot-v1-32k-vision-preview",
|
||||
"tags": "LLM,IMAGE2TEXT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshot-v1-128k-vision-preview",
|
||||
"tags": "LLM,IMAGE2TEXT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshot-v1-auto",
|
||||
"tags": "LLM,CHAT,",
|
||||
@ -979,6 +1021,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "970",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "deepseek-chat",
|
||||
@ -1157,6 +1200,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
"rank": "810",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "abab6.5-chat",
|
||||
@ -1196,6 +1240,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,MODERATION",
|
||||
"status": "1",
|
||||
"rank": "910",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "codestral-latest",
|
||||
@ -1289,6 +1334,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
"rank": "850",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gpt-4o-mini",
|
||||
@ -1373,6 +1419,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
"rank": "860",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
@ -1380,6 +1427,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
|
||||
"status": "1",
|
||||
"rank": "980",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gemini-2.5-flash",
|
||||
@ -1435,6 +1483,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "800",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gemma2-9b-it",
|
||||
@ -1494,7 +1543,8 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,IMAGE2TEXT",
|
||||
"status": "1",
|
||||
"llm": []
|
||||
"llm": [],
|
||||
"rank": "840"
|
||||
},
|
||||
{
|
||||
"name": "StepFun",
|
||||
@ -1544,6 +1594,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
|
||||
"status": "1",
|
||||
"rank": "790",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "01-ai/yi-large",
|
||||
@ -2298,6 +2349,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
|
||||
"status": "1",
|
||||
"rank": "900",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "command-r-plus",
|
||||
@ -2405,108 +2457,6 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "LeptonAI",
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "dolphin-mixtral-8x7b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemma-7b",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "llama3-1-8b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "llama3-8b",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "llama2-13b",
|
||||
"tags": "LLM,CHAT,4K",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "llama3-1-70b",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "llama3-70b",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "llama3-1-405b",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "mistral-7b",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "mistral-8x7b",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "nous-hermes-llama2",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "openchat-3-5",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "toppy-m-7b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "wizardlm-2-7b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "wizardlm-2-8x22b",
|
||||
"tags": "LLM,CHAT,64K",
|
||||
"max_tokens": 65536,
|
||||
"model_type": "chat"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "TogetherAI",
|
||||
"logo": "",
|
||||
@ -2514,167 +2464,6 @@
|
||||
"status": "1",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
"name": "PerfXCloud",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "deepseek-v2-chat",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "llama3.1:405b",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen2-72B-Instruct",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen2-72B-Instruct-GPTQ-Int4",
|
||||
"tags": "LLM,CHAT,2k",
|
||||
"max_tokens": 2048,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen2-72B-Instruct-awq-int4",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Llama3-Chinese_v2",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Yi-1_5-9B-Chat-16K",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen1.5-72B-Chat-GPTQ-Int4",
|
||||
"tags": "LLM,CHAT,2k",
|
||||
"max_tokens": 2048,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Meta-Llama-3.1-8B-Instruct",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen2-7B-Instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek-v2-lite-chat",
|
||||
"tags": "LLM,CHAT,2k",
|
||||
"max_tokens": 2048,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen2-7B",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "chatglm3-6b",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Meta-Llama-3-70B-Instruct-GPTQ-Int4",
|
||||
"tags": "LLM,CHAT,1k",
|
||||
"max_tokens": 1024,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Meta-Llama-3-8B-Instruct",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Mistral-7B-Instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "MindChat-Qwen-7B-v2",
|
||||
"tags": "LLM,CHAT,2k",
|
||||
"max_tokens": 2048,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "phi-2",
|
||||
"tags": "LLM,CHAT,2k",
|
||||
"max_tokens": 2048,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "SOLAR-10_7B-Instruct",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Mixtral-8x7B-Instruct-v0.1-GPTQ",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen1.5-7B",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "BAAI/bge-large-en-v1.5",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 512,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "BAAI/bge-large-zh-v1.5",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 1024,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "BAAI/bge-m3",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "embedding"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Upstage",
|
||||
"logo": "",
|
||||
@ -2839,12 +2628,13 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,IMAGE2TEXT",
|
||||
"status": "1",
|
||||
"rank": "780",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name":"THUDM/GLM-4.1V-9B-Thinking",
|
||||
"tags":"LLM,CHAT,IMAGE2TEXT, 64k",
|
||||
"max_tokens":64000,
|
||||
"model_type":"chat",
|
||||
"llm_name": "THUDM/GLM-4.1V-9B-Thinking",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT, 64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
@ -2938,13 +2728,6 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Pro/deepseek-ai/DeepSeek-V3-1226",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Pro/deepseek-ai/DeepSeek-V3.1",
|
||||
"tags": "LLM,CHAT,160k",
|
||||
@ -2987,20 +2770,6 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek-ai/DeepSeek-V2.5",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
@ -3057,13 +2826,6 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "THUDM/chatglm3-6b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "Pro/THUDM/glm-4-9b-chat",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
@ -3085,13 +2847,6 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen/QwQ-32B-Preview",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen/Qwen2.5-Coder-32B-Instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
@ -3155,13 +2910,6 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "internlm/internlm2_5-20b-chat",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "internlm/internlm2_5-7b-chat",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
@ -3197,13 +2945,6 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "Pro/Qwen/Qwen2-1.5B-Instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "BAAI/bge-m3",
|
||||
"tags": "LLM,EMBEDDING,8k",
|
||||
@ -3387,75 +3128,6 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "01.AI",
|
||||
"logo": "",
|
||||
"tags": "LLM,IMAGE2TEXT",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "yi-lightning",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-large",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-medium",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-medium-200k",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 204800,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-spark",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-large-rag",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-large-fc",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-large-turbo",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-large-preview",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "yi-vision",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "image2text"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Replicate",
|
||||
"logo": "",
|
||||
@ -3507,6 +3179,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM,TTS",
|
||||
"status": "1",
|
||||
"rank": "820",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
@ -3514,6 +3187,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "880",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
@ -3535,6 +3209,7 @@
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "990",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "claude-opus-4-1-20250805",
|
||||
@ -4136,6 +3811,7 @@
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING,TEXT RE-RANK",
|
||||
"status": "1",
|
||||
"rank": "920",
|
||||
"llm": []
|
||||
},
|
||||
{
|
||||
@ -4875,10 +4551,11 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Meituan",
|
||||
"name": "LongCat",
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "870",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "LongCat-Flash-Chat",
|
||||
@ -5164,4 +4841,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@ -32,6 +32,8 @@ redis:
|
||||
db: 1
|
||||
password: 'infini_rag_flow'
|
||||
host: 'localhost:6379'
|
||||
task_executor:
|
||||
message_queue_type: 'redis'
|
||||
user_default_llm:
|
||||
default_models:
|
||||
embedding_model:
|
||||
|
||||
15
docker/.env
15
docker/.env
@ -106,17 +106,11 @@ ADMIN_SVR_HTTP_PORT=9381
|
||||
SVR_MCP_PORT=9382
|
||||
|
||||
# The RAGFlow Docker image to download. v0.22+ doesn't include embedding models.
|
||||
# Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models.
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-slim
|
||||
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1
|
||||
# The Docker image of the v0.21.1 edition includes built-in embedding models:
|
||||
# - BAAI/bge-large-zh-v1.5
|
||||
# - maidalun1020/bce-embedding-base_v1
|
||||
#
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.22.0
|
||||
|
||||
# If you cannot download the RAGFlow Docker image:
|
||||
# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.21.1
|
||||
# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.21.1
|
||||
# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.22.0
|
||||
# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.22.0
|
||||
#
|
||||
# - For the `nightly` edition, uncomment either of the following:
|
||||
# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly
|
||||
@ -217,3 +211,6 @@ REGISTER_ENABLED=1
|
||||
# Enable DocLing and Mineru
|
||||
USE_DOCLING=false
|
||||
USE_MINERU=false
|
||||
|
||||
# pptx support
|
||||
DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
|
||||
@ -77,13 +77,7 @@ The [.env](./.env) file contains important environment variables for Docker.
|
||||
- `SVR_HTTP_PORT`
|
||||
The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`.
|
||||
- `RAGFLOW-IMAGE`
|
||||
The Docker image edition. Available editions:
|
||||
|
||||
- `infiniflow/ragflow:v0.21.1-slim` (default): The RAGFlow Docker image without embedding models.
|
||||
- `infiniflow/ragflow:v0.21.1`: The RAGFlow Docker image with embedding models including:
|
||||
- Built-in embedding models:
|
||||
- `BAAI/bge-large-zh-v1.5`
|
||||
- `maidalun1020/bce-embedding-base_v1`
|
||||
The Docker image edition. Defaults to `infiniflow/ragflow:v0.22.0`. The RAGFlow Docker image does not include embedding models.
|
||||
|
||||
|
||||
> [!TIP]
|
||||
|
||||
@ -72,7 +72,7 @@ services:
|
||||
infinity:
|
||||
profiles:
|
||||
- infinity
|
||||
image: infiniflow/infinity:v0.6.2
|
||||
image: infiniflow/infinity:v0.6.5
|
||||
volumes:
|
||||
- infinity_data:/var/infinity
|
||||
- ./infinity_conf.toml:/infinity_conf.toml
|
||||
@ -120,8 +120,8 @@ services:
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "http://localhost:9385/healthz"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
timeout: 10s
|
||||
retries: 120
|
||||
restart: on-failure
|
||||
|
||||
mysql:
|
||||
@ -149,7 +149,7 @@ services:
|
||||
test: ["CMD", "mysqladmin" ,"ping", "-uroot", "-p${MYSQL_PASSWORD}"]
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
retries: 120
|
||||
restart: on-failure
|
||||
|
||||
minio:
|
||||
@ -169,9 +169,9 @@ services:
|
||||
restart: on-failure
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 120
|
||||
|
||||
redis:
|
||||
# swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/valkey/valkey:8
|
||||
@ -187,10 +187,9 @@ services:
|
||||
restart: on-failure
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 120
|
||||
|
||||
|
||||
tei-cpu:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user