Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-04 03:25:30 +08:00)

Compare commits: c8b79dfed4 ... v0.21.0 (142 commits)
Commits in range (SHA1 only):

```
fdac4afd10 769d701f56 8b512cdadf 3ae126836a e8bfda6020 34c54cd459 3d873d98fb
fbe25b5add 0c6c7c8fe7 e266f9a66f fde6e5ab39 67529825e2 738a7d5c24 83ec915d51
e535099f36 16b5feadb7 960f47c4d4 51139de178 1f5167f1ca 578ea34b3e 5fb3d2f55c
d99d1e3518 5b387b68ba f92a45dcc4 c4b8e4845c 87659dcd3a 6fd9508017 113851a692
66c69d10fe 781d49cd0e aaae938f54 9e73f799b2 21a62130c8 68e47c81d4 f11d8af936
74ec734d69 8c75803b70 ff4239c7cf cf5867b146 77481ab3ab 9c53b3336a 24481f0332
4e6b84bb41 65c3f0406c 7fb8b30cc2 acca3640f7 58836d84fe ad56137a59 2828e321bc
932781ea4e 5200711441 c21cea2038 6a0f448419 7d2f65671f a0d5f81098 52f26f4643
313e92dd9b fee757eb41 b5ddc7ca05 534fa60b2a 390b2b8f26 0283e4098f 2cdba3d1e6
eb0b37d7ee 198e52e990 a50ccf77f9 deaf15a08b 0d8791936e 5d167cd772 f35c5ed119
fc46d6bb87 8252b1c5c0 c802a6ffdd 9b06734ced 6ab4c1a6e9 f631073ac2 8aabc2807c
d931c33ced f4324e89d9 f04c9e2937 1fc2889f98 ee0c38da66 c1806e1ab2 66d0d44a00
2078d88c28 7734ad7fcd 1a47e136e3 cbf04ee470 ef0aecea3b dfc5fa1f4d f341dc03b8
4585edc20e dba9158f9a 82f572ff95 518a00630e aa61ae24dc fb950079ef aec8c15e7e
7c620bdc69 e7dde69584 d6eded1959 80f851922a 17757930a3 a8883905a7 8426cbbd02
0b759f559c 2d5d10ecbf 954bd5a1c2 ccb1c269e8 6dfb0c245c 72d1047a8f bece37e6c8
59cb0eb8bc fc56217eb3 723cf9443e bd94b5dfb5 ef59c5bab9 62b7c655c5 b0b866c8fd
3a831d0c28 9e323a9351 7ac95b759b daea357940 4aa1abd8e5 922b5c652d aaa97874c6
193d93d820 4058715df7 3f595029d7 e8f5a4da56 a9472e3652 4dd48b60f3 e4ab8ba2de
a1f848bfe0 f2309ff93e 38be53cf31 65a06d62d8 10cbbb76f8 1c84d1b562 4eb7659499
46a61e5aff da82566304
```
`.github/workflows/release.yml` (vendored, 18 lines changed)

```diff
@@ -25,7 +25,7 @@ jobs:
       - name: Check out code
         uses: actions/checkout@v4
         with:
-          token: ${{ secrets.MY_GITHUB_TOKEN }} # Use the secret as an environment variable
+          token: ${{ secrets.GITHUB_TOKEN }} # Use the secret as an environment variable
           fetch-depth: 0
           fetch-tags: true

@@ -69,7 +69,7 @@ jobs:
         # https://github.com/actions/upload-release-asset has been replaced by https://github.com/softprops/action-gh-release
         uses: softprops/action-gh-release@v2
         with:
-          token: ${{ secrets.MY_GITHUB_TOKEN }} # Use the secret as an environment variable
+          token: ${{ secrets.GITHUB_TOKEN }} # Use the secret as an environment variable
           prerelease: ${{ env.PRERELEASE }}
           tag_name: ${{ env.RELEASE_TAG }}
           # The body field does not support environment variable substitution directly.
@@ -120,3 +120,17 @@ jobs:
           packages-dir: sdk/python/dist/
           password: ${{ secrets.PYPI_API_TOKEN }}
           verbose: true
+
+      - name: Build ragflow-cli
+        if: startsWith(github.ref, 'refs/tags/v')
+        run: |
+          cd admin/client && \
+          uv build
+
+      - name: Publish client package distributions to PyPI
+        if: startsWith(github.ref, 'refs/tags/v')
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          packages-dir: admin/client/dist/
+          password: ${{ secrets.PYPI_API_TOKEN }}
+          verbose: true
```
`.github/workflows/tests.yml` (vendored, 48 lines changed)

```diff
@@ -34,12 +34,10 @@ jobs:
       # https://github.com/hmarr/debug-action
       #- uses: hmarr/debug-action@v2

-      - name: Show who triggered this workflow
+      - name: Ensure workspace ownership
         run: |
           echo "Workflow triggered by ${{ github.event_name }}"
-
-      - name: Ensure workspace ownership
-        run: echo "chown -R $USER $GITHUB_WORKSPACE" && sudo chown -R $USER $GITHUB_WORKSPACE
+          echo "chown -R $USER $GITHUB_WORKSPACE" && sudo chown -R $USER $GITHUB_WORKSPACE

       # https://github.com/actions/checkout/issues/1781
       - name: Check out code
@@ -48,6 +46,44 @@ jobs:
           fetch-depth: 0
           fetch-tags: true

+      - name: Check workflow duplication
+        if: ${{ !cancelled() && !failure() && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'ci')) }}
+        run: |
+          if [[ ${{ github.event_name }} != 'pull_request' ]]; then
+            HEAD=$(git rev-parse HEAD)
+            # Find a PR that introduced a given commit
+            gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}"
+            PR_NUMBER=$(gh pr list --search ${HEAD} --state merged --json number --jq .[0].number)
+            echo "HEAD=${HEAD}"
+            echo "PR_NUMBER=${PR_NUMBER}"
+            if [[ -n ${PR_NUMBER} ]]; then
+              PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
+              if [[ -f ${PR_SHA_FP} ]]; then
+                read -r PR_SHA PR_RUN_ID < "${PR_SHA_FP}"
+                # Calculate the hash of the current workspace content
+                HEAD_SHA=$(git rev-parse HEAD^{tree})
+                if [[ ${HEAD_SHA} == ${PR_SHA} ]]; then
+                  echo "Cancel myself since the workspace content hash is the same with PR #${PR_NUMBER} merged. See ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${PR_RUN_ID} for details."
+                  gh run cancel ${GITHUB_RUN_ID}
+                  while true; do
+                    status=$(gh run view ${GITHUB_RUN_ID} --json status -q .status)
+                    [ "$status" = "completed" ] && break
+                    sleep 5
+                  done
+                  exit 1
+                fi
+              fi
+            fi
+          else
+            PR_NUMBER=${{ github.event.pull_request.number }}
+            PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
+            # Calculate the hash of the current workspace content
+            PR_SHA=$(git rev-parse HEAD^{tree})
+            echo "PR #${PR_NUMBER} workspace content hash: ${PR_SHA}"
+            mkdir -p ${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}
+            echo "${PR_SHA} ${GITHUB_RUN_ID}" > ${PR_SHA_FP}
+          fi
+
       # https://github.com/astral-sh/ruff-action
       - name: Static check with Ruff
         uses: astral-sh/ruff-action@v3
@@ -59,11 +95,11 @@ jobs:
         run: |
           RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-$HOME}
           sudo docker pull ubuntu:22.04
-          sudo docker build --progress=plain --build-arg LIGHTEN=1 --build-arg NEED_MIRROR=1 -f Dockerfile -t infiniflow/ragflow:nightly-slim .
+          sudo DOCKER_BUILDKIT=1 docker build --build-arg LIGHTEN=1 --build-arg NEED_MIRROR=1 -f Dockerfile -t infiniflow/ragflow:nightly-slim .

       - name: Build ragflow:nightly
         run: |
-          sudo docker build --progress=plain --build-arg NEED_MIRROR=1 -f Dockerfile -t infiniflow/ragflow:nightly .
+          sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 -f Dockerfile -t infiniflow/ragflow:nightly .

       - name: Start ragflow:nightly-slim
         run: |
```
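The duplication check above hinges on `git rev-parse HEAD^{tree}`: peeling a commit to its tree object yields a hash of the checked-in content alone, so a push whose content is identical to an already-tested merged PR produces the same hash and the run cancels itself. A minimal standalone sketch of that idea (assumes a git checkout and the `git` CLI on PATH; this is not part of the workflow itself):

```python
# Hedged sketch: content-addressed dedup via git tree hashes.
import subprocess

def tree_hash(rev: str = "HEAD") -> str:
    # "<rev>^{tree}" peels the commit object down to the tree it points at,
    # so the result depends only on the checked-in content, not on commit
    # metadata such as author, date, or parent SHAs.
    return subprocess.check_output(
        ["git", "rev-parse", f"{rev}^{{tree}}"], text=True
    ).strip()

if __name__ == "__main__":
    print(tree_hash())  # identical across commits whose content is identical
```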
`.gitignore` (vendored, 2 lines changed)

```diff
@@ -149,7 +149,7 @@ out
 # Nuxt.js build / generate output
 .nuxt
 dist
-
+ragflow_cli.egg-info
 # Gatsby files
 .cache/
 # Comment in the public line in if your project uses Gatsby and not Next.js
```
`Dockerfile`

```diff
@@ -191,6 +191,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 ENV PYTHONPATH=/ragflow/

 COPY web web
+COPY admin admin
 COPY api api
 COPY conf conf
 COPY deepdoc deepdoc
```
`README.md` (20 lines changed)

````diff
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="520" alt="ragflow logo">
+<img src="web/src/assets/logo-with-text.svg" width="520" alt="ragflow logo">
 </a>
 </div>

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.0">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -84,8 +84,8 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io).

 ## 🔥 Latest Updates

+- 2025-10-15 Supports orchestrable ingestion pipeline.
 - 2025-08-08 Supports OpenAI's latest GPT-5 series models.
-- 2025-08-04 Supports new models, including Kimi K2 and Grok 4.
 - 2025-08-01 Supports agentic workflow and MCP.
 - 2025-05-23 Adds a Python/JavaScript code executor component to Agent.
 - 2025-05-05 Supports cross-language query.
@@ -187,7 +187,7 @@ releases! 🌟
 > All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64.
 > If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system.

-> The command below downloads the `v0.20.5-slim` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.20.5-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` for the full edition `v0.20.5`.
+> The command below downloads the `v0.21.0-slim` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.21.0-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.0` for the full edition `v0.21.0`.

 ```bash
 $ cd ragflow/docker
@@ -200,8 +200,8 @@ releases! 🌟

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable?                  |
 |-------------------|-----------------|-----------------------|--------------------------|
-| v0.20.5           | ≈9              | :heavy_check_mark:    | Stable release           |
-| v0.20.5-slim      | ≈2              | ❌                    | Stable release           |
+| v0.21.0           | ≈9              | :heavy_check_mark:    | Stable release           |
+| v0.21.0-slim      | ≈2              | ❌                    | Stable release           |
 | nightly           | ≈9              | :heavy_check_mark:    | _Unstable_ nightly build |
 | nightly-slim      | ≈2              | ❌                    | _Unstable_ nightly build |
@@ -341,11 +341,13 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly
 5. If your operating system does not have jemalloc, please install it as follows:

    ```bash
-   # ubuntu
+   # Ubuntu
    sudo apt-get install libjemalloc-dev
-   # centos
+   # CentOS
    sudo yum install jemalloc
-   # mac
+   # OpenSUSE
+   sudo zypper install jemalloc
+   # macOS
    sudo brew install jemalloc
    ```
````
`README_id.md` (12 lines changed)

````diff
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="520" alt="Logo ragflow">
+<img src="web/src/assets/logo-with-text.svg" width="520" alt="Logo ragflow">
 </a>
 </div>

@@ -22,7 +22,7 @@
 <img alt="Lencana Daring" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.0">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Rilis%20Terbaru" alt="Rilis Terbaru">
@@ -80,8 +80,8 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).

 ## 🔥 Pembaruan Terbaru

+- 2025-10-15 Dukungan untuk jalur data yang terorkestrasi.
 - 2025-08-08 Mendukung model seri GPT-5 terbaru dari OpenAI.
-- 2025-08-04 Mendukung model baru, termasuk Kimi K2 dan Grok 4.
 - 2025-08-01 Mendukung alur kerja agen dan MCP.
 - 2025-05-23 Menambahkan komponen pelaksana kode Python/JS ke Agen.
 - 2025-05-05 Mendukung kueri lintas bahasa.
@@ -181,7 +181,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 > Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64.
 > Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image).

-> Perintah di bawah ini mengunduh edisi v0.20.5-slim dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.20.5-slim, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. Misalnya, atur RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5 untuk edisi lengkap v0.20.5.
+> Perintah di bawah ini mengunduh edisi v0.21.0-slim dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.21.0-slim, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. Misalnya, atur RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.0 untuk edisi lengkap v0.21.0.

 ```bash
 $ cd ragflow/docker
@@ -194,8 +194,8 @@ $ docker compose -f docker-compose.yml up -d

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable?                  |
 | ----------------- | --------------- | --------------------- | ------------------------ |
-| v0.20.5           | ≈9              | :heavy_check_mark:    | Stable release           |
-| v0.20.5-slim      | ≈2              | ❌                    | Stable release           |
+| v0.21.0           | ≈9              | :heavy_check_mark:    | Stable release           |
+| v0.21.0-slim      | ≈2              | ❌                    | Stable release           |
 | nightly           | ≈9              | :heavy_check_mark:    | _Unstable_ nightly build |
 | nightly-slim      | ≈2              | ❌                    | _Unstable_ nightly build |
````
`README_ja.md` (12 lines changed)

````diff
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="350" alt="ragflow logo">
+<img src="web/src/assets/logo-with-text.svg" width="350" alt="ragflow logo">
 </a>
 </div>

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.0">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -60,8 +60,8 @@

 ## 🔥 最新情報

+- 2025-10-15 オーケストレーションされたデータパイプラインのサポート。
 - 2025-08-08 OpenAI の最新 GPT-5 シリーズモデルをサポートします。
-- 2025-08-04 新モデル、キミK2およびGrok 4をサポート。
 - 2025-08-01 エージェントワークフローとMCPをサポート。
 - 2025-05-23 エージェントに Python/JS コードエグゼキュータコンポーネントを追加しました。
 - 2025-05-05 言語間クエリをサポートしました。
@@ -160,7 +160,7 @@
 > 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。
 > ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。

-> 以下のコマンドは、RAGFlow Docker イメージの v0.20.5-slim エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.20.5-slim とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。例えば、完全版 v0.20.5 をダウンロードするには、RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5 と設定します。
+> 以下のコマンドは、RAGFlow Docker イメージの v0.21.0-slim エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.21.0-slim とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。例えば、完全版 v0.21.0 をダウンロードするには、RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.0 と設定します。

 ```bash
 $ cd ragflow/docker
@@ -173,8 +173,8 @@

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable?                  |
 | ----------------- | --------------- | --------------------- | ------------------------ |
-| v0.20.5           | ≈9              | :heavy_check_mark:    | Stable release           |
-| v0.20.5-slim      | ≈2              | ❌                    | Stable release           |
+| v0.21.0           | ≈9              | :heavy_check_mark:    | Stable release           |
+| v0.21.0-slim      | ≈2              | ❌                    | Stable release           |
 | nightly           | ≈9              | :heavy_check_mark:    | _Unstable_ nightly build |
 | nightly-slim      | ≈2              | ❌                    | _Unstable_ nightly build |
````
`README_ko.md` (12 lines changed)

````diff
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="520" alt="ragflow logo">
+<img src="web/src/assets/logo-with-text.svg" width="520" alt="ragflow logo">
 </a>
 </div>

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.0">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -60,8 +60,8 @@

 ## 🔥 업데이트

+- 2025-10-15 조정된 데이터 파이프라인 지원.
 - 2025-08-08 OpenAI의 최신 GPT-5 시리즈 모델을 지원합니다.
-- 2025-08-04 새로운 모델인 Kimi K2와 Grok 4를 포함하여 지원합니다.
 - 2025-08-01 에이전트 워크플로우와 MCP를 지원합니다.
 - 2025-05-23 Agent에 Python/JS 코드 실행기 구성 요소를 추가합니다.
 - 2025-05-05 언어 간 쿼리를 지원합니다.
@@ -160,7 +160,7 @@
 > 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다.
 > ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image).

-> 아래 명령어는 RAGFlow Docker 이미지의 v0.20.5-slim 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.20.5-slim과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. 예를 들어, 전체 버전인 v0.20.5을 다운로드하려면 RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5로 설정합니다.
+> 아래 명령어는 RAGFlow Docker 이미지의 v0.21.0-slim 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.21.0-slim과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. 예를 들어, 전체 버전인 v0.21.0을 다운로드하려면 RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.0로 설정합니다.

 ```bash
 $ cd ragflow/docker
@@ -173,8 +173,8 @@

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable?                  |
 | ----------------- | --------------- | --------------------- | ------------------------ |
-| v0.20.5           | ≈9              | :heavy_check_mark:    | Stable release           |
-| v0.20.5-slim      | ≈2              | ❌                    | Stable release           |
+| v0.21.0           | ≈9              | :heavy_check_mark:    | Stable release           |
+| v0.21.0-slim      | ≈2              | ❌                    | Stable release           |
 | nightly           | ≈9              | :heavy_check_mark:    | _Unstable_ nightly build |
 | nightly-slim      | ≈2              | ❌                    | _Unstable_ nightly build |
````
`README_pt_br.md`

````diff
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="520" alt="ragflow logo">
+<img src="web/src/assets/logo-with-text.svg" width="520" alt="ragflow logo">
 </a>
 </div>

@@ -22,7 +22,7 @@
 <img alt="Badge Estático" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.0">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Última%20Relese" alt="Última Versão">
@@ -80,8 +80,8 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).

 ## 🔥 Últimas Atualizações

+- 10-15-2025 Suporte para pipelines de dados orquestrados.
 - 08-08-2025 Suporta a mais recente série GPT-5 da OpenAI.
-- 04-08-2025 Suporta novos modelos, incluindo Kimi K2 e Grok 4.
 - 01-08-2025 Suporta fluxo de trabalho agente e MCP.
 - 23-05-2025 Adicione o componente executor de código Python/JS ao Agente.
 - 05-05-2025 Suporte a consultas entre idiomas.
@@ -180,7 +180,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 > Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64.
 > Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema.

-> O comando abaixo baixa a edição `v0.20.5-slim` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.20.5-slim`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. Por exemplo: defina `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` para a edição completa `v0.20.5`.
+> O comando abaixo baixa a edição `v0.21.0-slim` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.21.0-slim`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. Por exemplo: defina `RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.0` para a edição completa `v0.21.0`.

 ```bash
 $ cd ragflow/docker
@@ -193,8 +193,8 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).

 | Tag da imagem RAGFlow | Tamanho da imagem (GB) | Possui modelos de incorporação? | Estável?                 |
 | --------------------- | ---------------------- | ------------------------------- | ------------------------ |
-| v0.20.5               | ~9                     | :heavy_check_mark:              | Lançamento estável       |
-| v0.20.5-slim          | ~2                     | ❌                              | Lançamento estável       |
+| v0.21.0               | ~9                     | :heavy_check_mark:              | Lançamento estável       |
+| v0.21.0-slim          | ~2                     | ❌                              | Lançamento estável       |
 | nightly               | ~9                     | :heavy_check_mark:              | _Instável_ build noturno |
 | nightly-slim          | ~2                     | ❌                              | _Instável_ build noturno |
````
`README_tzh.md`

````diff
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="350" alt="ragflow logo">
+<img src="web/src/assets/logo-with-text.svg" width="350" alt="ragflow logo">
 </a>
 </div>

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.0">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -83,8 +83,8 @@

 ## 🔥 近期更新

+- 2025-10-15 支援可編排的資料管道。
 - 2025-08-08 支援 OpenAI 最新的 GPT-5 系列模型。
-- 2025-08-04 支援 Kimi K2 和 Grok 4 等模型.
 - 2025-08-01 支援 agentic workflow 和 MCP
 - 2025-05-23 為 Agent 新增 Python/JS 程式碼執行器元件。
 - 2025-05-05 支援跨語言查詢。
@@ -183,7 +183,7 @@
 > 所有 Docker 映像檔都是為 x86 平台建置的。目前，我們不提供 ARM64 平台的 Docker 映像檔。
 > 如果您使用的是 ARM64 平台，請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。

-> 執行以下指令會自動下載 RAGFlow slim Docker 映像 `v0.20.5-slim`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.20.5-slim` 的 Docker 映像，請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。例如，你可以透過設定 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` 來下載 RAGFlow 鏡像的 `v0.20.5` 完整發行版。
+> 執行以下指令會自動下載 RAGFlow slim Docker 映像 `v0.21.0-slim`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.21.0-slim` 的 Docker 映像，請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。例如，你可以透過設定 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.0` 來下載 RAGFlow 鏡像的 `v0.21.0` 完整發行版。

 ```bash
 $ cd ragflow/docker
@@ -196,8 +196,8 @@

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable?                  |
 | ----------------- | --------------- | --------------------- | ------------------------ |
-| v0.20.5           | ≈9              | :heavy_check_mark:    | Stable release           |
-| v0.20.5-slim      | ≈2              | ❌                    | Stable release           |
+| v0.21.0           | ≈9              | :heavy_check_mark:    | Stable release           |
+| v0.21.0-slim      | ≈2              | ❌                    | Stable release           |
 | nightly           | ≈9              | :heavy_check_mark:    | _Unstable_ nightly build |
 | nightly-slim      | ≈2              | ❌                    | _Unstable_ nightly build |
````
`README_zh.md` (14 lines changed)

````diff
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://demo.ragflow.io/">
-<img src="web/src/assets/logo-with-text.png" width="350" alt="ragflow logo">
+<img src="web/src/assets/logo-with-text.svg" width="350" alt="ragflow logo">
 </a>
 </div>

@@ -22,7 +22,7 @@
 <img alt="Static Badge" src="https://img.shields.io/badge/Online-Demo-4e6b99">
 </a>
 <a href="https://hub.docker.com/r/infiniflow/ragflow" target="_blank">
-<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.20.5">
+<img src="https://img.shields.io/docker/pulls/infiniflow/ragflow?label=Docker%20Pulls&color=0db7ed&logo=docker&logoColor=white&style=flat-square" alt="docker pull infiniflow/ragflow:v0.21.0">
 </a>
 <a href="https://github.com/infiniflow/ragflow/releases/latest">
 <img src="https://img.shields.io/github/v/release/infiniflow/ragflow?color=blue&label=Latest%20Release" alt="Latest Release">
@@ -83,8 +83,8 @@

 ## 🔥 近期更新

-- 2025-08-08 支持 OpenAI 最新的 GPT-5 系列模型.
-- 2025-08-04 新增对 Kimi K2 和 Grok 4 等模型的支持.
+- 2025-10-15 支持可编排的数据管道。
+- 2025-08-08 支持 OpenAI 最新的 GPT-5 系列模型。
 - 2025-08-01 支持 agentic workflow 和 MCP。
 - 2025-05-23 Agent 新增 Python/JS 代码执行器组件。
 - 2025-05-05 支持跨语言查询。
@@ -183,7 +183,7 @@
 > 请注意，目前官方提供的所有 Docker 镜像均基于 x86 架构构建，并不提供基于 ARM64 的 Docker 镜像。
 > 如果你的操作系统是 ARM64 架构，请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。

-> 运行以下命令会自动下载 RAGFlow slim Docker 镜像 `v0.20.5-slim`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.20.5-slim` 的 Docker 镜像，请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。比如，你可以通过设置 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.20.5` 来下载 RAGFlow 镜像的 `v0.20.5` 完整发行版。
+> 运行以下命令会自动下载 RAGFlow slim Docker 镜像 `v0.21.0-slim`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.21.0-slim` 的 Docker 镜像，请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。比如，你可以通过设置 `RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.0` 来下载 RAGFlow 镜像的 `v0.21.0` 完整发行版。

 ```bash
 $ cd ragflow/docker
@@ -196,8 +196,8 @@

 | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable?                  |
 | ----------------- | --------------- | --------------------- | ------------------------ |
-| v0.20.5           | ≈9              | :heavy_check_mark:    | Stable release           |
-| v0.20.5-slim      | ≈2              | ❌                    | Stable release           |
+| v0.21.0           | ≈9              | :heavy_check_mark:    | Stable release           |
+| v0.21.0-slim      | ≈2              | ❌                    | Stable release           |
 | nightly           | ≈9              | :heavy_check_mark:    | _Unstable_ nightly build |
 | nightly-slim      | ≈2              | ❌                    | _Unstable_ nightly build |
````
`admin/build_cli_release.sh` (new executable file, 47 lines)

```bash
#!/bin/bash

set -e

echo "🚀 Start building..."
echo "================================"

PROJECT_NAME="ragflow-cli"

RELEASE_DIR="release"
BUILD_DIR="dist"
SOURCE_DIR="src"
PACKAGE_DIR="ragflow_cli"

echo "🧹 Clean old build folder..."
rm -rf release/

echo "📁 Prepare source code..."
mkdir release/$PROJECT_NAME/$SOURCE_DIR -p
cp pyproject.toml release/$PROJECT_NAME/pyproject.toml
cp README.md release/$PROJECT_NAME/README.md

mkdir release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR -p
cp admin_client.py release/$PROJECT_NAME/$SOURCE_DIR/$PACKAGE_DIR/admin_client.py

if [ -d "release/$PROJECT_NAME/$SOURCE_DIR" ]; then
    echo "✅ source dir: release/$PROJECT_NAME/$SOURCE_DIR"
else
    echo "❌ source dir not exist: release/$PROJECT_NAME/$SOURCE_DIR"
    exit 1
fi

echo "🔨 Make build file..."
cd release/$PROJECT_NAME
export PYTHONPATH=$(pwd)
python -m build

echo "✅ check build result..."
if [ -d "$BUILD_DIR" ]; then
    echo "📦 Package generated:"
    ls -la $BUILD_DIR/
else
    echo "❌ Build Failed: $BUILD_DIR not exist."
    exit 1
fi

echo "🎉 Build finished successfully!"
```
`admin/README.md`

````diff
@@ -15,22 +15,48 @@ It consists of a server-side Service and a command-line client (CLI), both imple
 - **Admin Service**: A backend service that interfaces with the RAGFlow system to execute administrative operations and monitor its status.
 - **Admin CLI**: A command-line interface that allows users to connect to the Admin Service and issue commands for system management.

 ### Starting the Admin Service

+#### Launching from source code
+
 1. Before start Admin Service, please make sure RAGFlow system is already started.

-2. Run the service script:
+2. Launch from source code:
+
    ```bash
-   python admin/admin_server.py
+   python admin/server/admin_server.py
    ```
    The service will start and listen for incoming connections from the CLI on the configured port.

+#### Using docker image
+
+1. Before startup, please configure the `docker_compose.yml` file to enable admin server:
+
+   ```bash
+   command:
+     - --enable-adminserver
+   ```
+
+2. Start the containers, the service will start and listen for incoming connections from the CLI on the configured port.
+
 ### Using the Admin CLI

 1. Ensure the Admin Service is running.
-2. Launch the CLI client:
+2. Install ragflow-cli.
    ```bash
-   python admin/admin_client.py -h 0.0.0.0 -p 9381
+   pip install ragflow-cli
+   ```
+3. Launch the CLI client:
+   ```bash
+   ragflow-cli -h 0.0.0.0 -p 9381
    ```
+   Enter superuser's password to login. Default password is `admin`.
@@ -42,12 +68,7 @@ Commands are case-insensitive and must be terminated with a semicolon (`;`).
   - Lists all available services within the RAGFlow system.
 - `SHOW SERVICE <id>;`
   - Shows detailed status information for the service identified by `<id>`.
-- `STARTUP SERVICE <id>;`
-  - Attempts to start the service identified by `<id>`.
-- `SHUTDOWN SERVICE <id>;`
-  - Attempts to gracefully shut down the service identified by `<id>`.
-- `RESTART SERVICE <id>;`
-  - Attempts to restart the service identified by `<id>`.

 ### User Management Commands

@@ -55,10 +76,17 @@ Commands are case-insensitive and must be terminated with a semicolon (`;`).
   - Lists all users known to the system.
 - `SHOW USER '<username>';`
   - Shows details and permissions for the specified user. The username must be enclosed in single or double quotes.
+
+- `CREATE USER <username> <password>;`
+  - Create user by username and password. The username and password must be enclosed in single or double quotes.
+
 - `DROP USER '<username>';`
   - Removes the specified user from the system. Use with caution.
 - `ALTER USER PASSWORD '<username>' '<new_password>';`
   - Changes the password for the specified user.
+- `ALTER USER ACTIVE <username> <on/off>;`
+  - Changes the user to active or inactive.

 ### Data and Agent Commands
````
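Taken together, a first session against a default deployment might look like the transcript below. It is illustrative only: the service ID and username are hypothetical, while the `admin> ` prompt, the `\q` meta-command, and the default password `admin` come from the changes shown in this compare.

```
$ ragflow-cli -h 0.0.0.0 -p 9381
admin> LIST SERVICES;
admin> SHOW SERVICE 0;
admin> CREATE USER 'alice' 'secret123';
admin> ALTER USER ACTIVE 'alice' on;
admin> \q
```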
`admin/client/admin_client.py`

```diff
@@ -1,7 +1,27 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import argparse
 import base64
+from cmd import Cmd
+
+from Cryptodome.PublicKey import RSA
+from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
 from typing import Dict, List, Any
-from lark import Lark, Transformer, Tree
+from lark import Lark, Transformer, Tree, Token
 import requests
 from requests.auth import HTTPBasicAuth
@@ -19,6 +39,8 @@ sql_command: list_services
   | show_user
   | drop_user
   | alter_user
+  | create_user
+  | activate_user
   | list_datasets
   | list_agents

@@ -35,6 +57,7 @@ meta_arg: /[^\\s"']+/ | quoted_string
 LIST: "LIST"i
 SERVICES: "SERVICES"i
 SHOW: "SHOW"i
+CREATE: "CREATE"i
 SERVICE: "SERVICE"i
 SHUTDOWN: "SHUTDOWN"i
 STARTUP: "STARTUP"i
@@ -43,6 +66,7 @@ USERS: "USERS"i
 DROP: "DROP"i
 USER: "USER"i
 ALTER: "ALTER"i
+ACTIVE: "ACTIVE"i
 PASSWORD: "PASSWORD"i
 DATASETS: "DATASETS"i
 OF: "OF"i
@@ -58,12 +82,15 @@ list_users: LIST USERS ";"
 drop_user: DROP USER quoted_string ";"
 alter_user: ALTER USER PASSWORD quoted_string quoted_string ";"
 show_user: SHOW USER quoted_string ";"
+create_user: CREATE USER quoted_string quoted_string ";"
+activate_user: ALTER USER ACTIVE quoted_string status ";"

 list_datasets: LIST DATASETS OF quoted_string ";"
 list_agents: LIST AGENTS OF quoted_string ";"

 identifier: WORD
 quoted_string: QUOTED_STRING
+status: WORD

 QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
 WORD: /[a-zA-Z0-9_\-\.]+/
```
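The grammar additions above follow the file's existing Lark pattern: one case-insensitive keyword terminal per word, one rule per statement, and a matching `AdminTransformer` method (next hunk) that reduces the parse tree to a command dict. A self-contained sketch of that pattern, shrunk to just the new `create_user` rule (illustrative names; the real GRAMMAR has many more rules):

```python
# Reduced sketch of the Lark grammar pattern used in admin_client.py.
from lark import Lark, Transformer

MINI_GRAMMAR = r"""
start: create_user
create_user: CREATE USER QUOTED_STRING QUOTED_STRING ";"
CREATE: "CREATE"i
USER: "USER"i
QUOTED_STRING: /'[^']+'/ | /"[^"]+"/
%import common.WS
%ignore WS
"""

class MiniTransformer(Transformer):
    def start(self, items):
        return items[0]

    def create_user(self, items):
        # items: [CREATE, USER, username, password]; the anonymous ";" token
        # is filtered out by Lark. Strip the surrounding quotes.
        return {"type": "create_user",
                "username": str(items[2])[1:-1],
                "password": str(items[3])[1:-1]}

parser = Lark(MINI_GRAMMAR, start="start", parser="lalr", transformer=MiniTransformer())
print(parser.parse("CREATE USER 'alice' 'secret123';"))
# -> {'type': 'create_user', 'username': 'alice', 'password': 'secret123'}
```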
```diff
@@ -118,6 +145,16 @@ class AdminTransformer(Transformer):
         new_password = items[4]
         return {"type": "alter_user", "username": user_name, "password": new_password}

+    def create_user(self, items):
+        user_name = items[2]
+        password = items[3]
+        return {"type": "create_user", "username": user_name, "password": password, "role": "user"}
+
+    def activate_user(self, items):
+        user_name = items[3]
+        activate_status = items[4]
+        return {"type": "activate_user", "activate_status": activate_status, "username": user_name}
+
     def list_datasets(self, items):
         user_name = items[3]
         return {"type": "list_datasets", "username": user_name}
```
```diff
@@ -147,17 +184,67 @@ class AdminTransformer(Transformer):
         return items


+def encrypt(input_string):
+    pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----'
+    pub_key = RSA.importKey(pub)
+    cipher = Cipher_pkcs1_v1_5.new(pub_key)
+    cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8')))
+    return base64.b64encode(cipher_text).decode("utf-8")
+
+
 def encode_to_base64(input_string):
     base64_encoded = base64.b64encode(input_string.encode('utf-8'))
     return base64_encoded.decode('utf-8')


-class AdminCommandParser:
+class AdminCLI(Cmd):
     def __init__(self):
+        super().__init__()
         self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer())
         self.command_history = []
+        self.is_interactive = False
+        self.admin_account = "admin@ragflow.io"
+        self.admin_password: str = "admin"
+        self.host: str = ""
+        self.port: int = 0

-    def parse_command(self, command_str: str) -> Dict[str, Any]:
+    intro = r"""Type "\h" for help."""
+    prompt = "admin> "
+
+    def onecmd(self, command: str) -> bool:
+        try:
+            # print(f"command: {command}")
+            result = self.parse_command(command)
+
+            # if 'type' in result and result.get('type') == 'empty':
+            #     return False
+
+            if isinstance(result, dict):
+                if 'type' in result and result.get('type') == 'empty':
+                    return False
+
+                self.execute_command(result)
+
+            if isinstance(result, Tree):
+                return False
+
+            if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']:
+                return True
+
+        except KeyboardInterrupt:
+            print("\nUse '\\q' to quit")
+        except EOFError:
+            print("\nGoodbye!")
+            return True
+        return False
+
+    def emptyline(self) -> bool:
+        return False
+
+    def default(self, line: str) -> bool:
+        return self.onecmd(line)
+
+    def parse_command(self, command_str: str) -> dict[str, str] | Tree[Token]:
         if not command_str.strip():
             return {'type': 'empty'}
```
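Replacing the old `AdminCommandParser` wrapper with a `cmd.Cmd` subclass in the hunk above is what turns the client into a proper REPL: `Cmd.cmdloop()` prints `intro`, shows `prompt`, routes each line through `onecmd()`/`default()`, and stops once that call returns a truthy value. A stripped-down, hypothetical illustration of that mechanic (standard library only; not the real class):

```python
# Self-contained sketch of the cmd.Cmd mechanics relied on above.
from cmd import Cmd

class MiniCLI(Cmd):
    intro = r"""Type "\h" for help."""
    prompt = "admin> "

    def default(self, line: str) -> bool:
        # With no do_* methods defined, every line lands here.
        print(f"would parse: {line!r}")
        # A truthy return value propagates up through onecmd() and stops cmdloop().
        return line.strip() in (r"\q", "quit", "exit")

    def emptyline(self) -> bool:
        # Override Cmd's default behavior of repeating the previous command.
        return False

if __name__ == "__main__":
    MiniCLI().cmdloop()
```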
@@ -169,16 +256,6 @@ class AdminCommandParser:
         except Exception as e:
             return {'type': 'error', 'message': f'Parse error: {str(e)}'}


-class AdminCLI:
-    def __init__(self):
-        self.parser = AdminCommandParser()
-        self.is_interactive = False
-        self.admin_account = "admin@ragflow.io"
-        self.admin_password: str = "admin"
-        self.host: str = ""
-        self.port: int = 0
-
     def verify_admin(self, args):

         conn_info = self._parse_connection_args(args)
@@ -220,14 +297,32 @@ class AdminCLI:
         if not data:
             print("No data to print")
             return
+        if isinstance(data, dict):
+            # handle single row data
+            data = [data]

         columns = list(data[0].keys())
         col_widths = {}

+        def get_string_width(text):
+            half_width_chars = (
+                " !\"#$%&'()*+,-./0123456789:;<=>?@"
+                "ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`"
+                "abcdefghijklmnopqrstuvwxyz{|}~"
+                "\t\n\r"
+            )
+            width = 0
+            for char in text:
+                if char in half_width_chars:
+                    width += 1
+                else:
+                    width += 2
+            return width
+
         for col in columns:
-            max_width = len(str(col))
+            max_width = get_string_width(str(col))
             for item in data:
-                value_len = len(str(item.get(col, '')))
+                value_len = get_string_width(str(item.get(col, '')))
                 if value_len > max_width:
                     max_width = value_len
             col_widths[col] = max(2, max_width)
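get_string_width treats ASCII printables (plus tab, newline, carriage return) as one terminal cell and everything else — CJK in particular — as two, so column widths line up for mixed-script data. Pulled out of its enclosing method, it behaves like this:

assert get_string_width("email") == 5    # ASCII: one cell per character
assert get_string_width("用户") == 4      # CJK: two cells per character
assert get_string_width("id用户") == 6    # mixed content
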
@@ -265,7 +360,7 @@ class AdminCLI:
                 continue

             print(f"command: {command}")
-            result = self.parser.parse_command(command)
+            result = self.parse_command(command)
             self.execute_command(result)

             if isinstance(result, Tree):
@@ -335,6 +430,10 @@ class AdminCLI:
                 self._handle_drop_user(command_dict)
             case 'alter_user':
                 self._handle_alter_user(command_dict)
+            case 'create_user':
+                self._handle_create_user(command_dict)
+            case 'activate_user':
+                self._handle_activate_user(command_dict)
             case 'list_datasets':
                 self._handle_list_datasets(command_dict)
             case 'list_agents':
@@ -349,9 +448,8 @@ class AdminCLI:

         url = f'http://{self.host}:{self.port}/api/v1/admin/services'
         response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
-        res_json = dict
+        res_json = response.json()
         if response.status_code == 200:
-            res_json = response.json()
             self._print_table_simple(res_json['data'])
         else:
             print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
@@ -360,6 +458,22 @@ class AdminCLI:
         service_id: int = command['number']
         print(f"Showing service: {service_id}")

+        url = f'http://{self.host}:{self.port}/api/v1/admin/services/{service_id}'
+        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
+        res_json = response.json()
+        if response.status_code == 200:
+            res_data = res_json['data']
+            if res_data['alive']:
+                print(f"Service {res_data['service_name']} is alive. Detail:")
+                if isinstance(res_data['message'], str):
+                    print(res_data['message'])
+                else:
+                    self._print_table_simple(res_data['message'])
+            else:
+                print(f"Service {res_data['service_name']} is down. Detail: {res_data['message']}")
+        else:
+            print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}")

     def _handle_restart_service(self, command):
         service_id: int = command['number']
         print(f"Restart service {service_id}")
@@ -377,9 +491,8 @@ class AdminCLI:

         url = f'http://{self.host}:{self.port}/api/v1/admin/users'
         response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
-        res_json = dict
+        res_json = response.json()
         if response.status_code == 200:
-            res_json = response.json()
             self._print_table_simple(res_json['data'])
         else:
             print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}")
@@ -388,11 +501,25 @@ class AdminCLI:
         username_tree: Tree = command['username']
         username: str = username_tree.children[0].strip("'\"")
         print(f"Showing user: {username}")
+        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
+        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
+        res_json = response.json()
+        if response.status_code == 200:
+            self._print_table_simple(res_json['data'])
+        else:
+            print(f"Fail to get user {username}, code: {res_json['code']}, message: {res_json['message']}")

     def _handle_drop_user(self, command):
         username_tree: Tree = command['username']
         username: str = username_tree.children[0].strip("'\"")
         print(f"Drop user: {username}")
+        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
+        response = requests.delete(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
+        res_json = response.json()
+        if response.status_code == 200:
+            print(res_json["message"])
+        else:
+            print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")

     def _handle_alter_user(self, command):
         username_tree: Tree = command['username']
@@ -400,16 +527,75 @@ class AdminCLI:
         password_tree: Tree = command['password']
         password: str = password_tree.children[0].strip("'\"")
         print(f"Alter user: {username}, password: {password}")
+        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/password'
+        response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
+                                json={'new_password': encrypt(password)})
+        res_json = response.json()
+        if response.status_code == 200:
+            print(res_json["message"])
+        else:
+            print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}")
+
+    def _handle_create_user(self, command):
+        username_tree: Tree = command['username']
+        username: str = username_tree.children[0].strip("'\"")
+        password_tree: Tree = command['password']
+        password: str = password_tree.children[0].strip("'\"")
+        role: str = command['role']
+        print(f"Create user: {username}, password: {password}, role: {role}")
+        url = f'http://{self.host}:{self.port}/api/v1/admin/users'
+        response = requests.post(
+            url,
+            auth=HTTPBasicAuth(self.admin_account, self.admin_password),
+            json={'username': username, 'password': encrypt(password), 'role': role}
+        )
+        res_json = response.json()
+        if response.status_code == 200:
+            self._print_table_simple(res_json['data'])
+        else:
+            print(f"Fail to create user {username}, code: {res_json['code']}, message: {res_json['message']}")
+
+    def _handle_activate_user(self, command):
+        username_tree: Tree = command['username']
+        username: str = username_tree.children[0].strip("'\"")
+        activate_tree: Tree = command['activate_status']
+        activate_status: str = activate_tree.children[0].strip("'\"")
+        if activate_status.lower() in ['on', 'off']:
+            print(f"Alter user {username} activate status, turn {activate_status.lower()}.")
+            url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/activate'
+            response = requests.put(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password),
+                                    json={'activate_status': activate_status})
+            res_json = response.json()
+            if response.status_code == 200:
+                print(res_json["message"])
+            else:
+                print(f"Fail to alter activate status, code: {res_json['code']}, message: {res_json['message']}")
+        else:
+            print(f"Unknown activate status: {activate_status}.")

     def _handle_list_datasets(self, command):
         username_tree: Tree = command['username']
         username: str = username_tree.children[0].strip("'\"")
         print(f"Listing all datasets of user: {username}")
+        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/datasets'
+        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
+        res_json = response.json()
+        if response.status_code == 200:
+            self._print_table_simple(res_json['data'])
+        else:
+            print(f"Fail to get all datasets of {username}, code: {res_json['code']}, message: {res_json['message']}")

     def _handle_list_agents(self, command):
         username_tree: Tree = command['username']
         username: str = username_tree.children[0].strip("'\"")
         print(f"Listing all agents of user: {username}")
+        url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}/agents'
+        response = requests.get(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
+        res_json = response.json()
+        if response.status_code == 200:
+            self._print_table_simple(res_json['data'])
+        else:
+            print(f"Fail to get all agents of {username}, code: {res_json['code']}, message: {res_json['message']}")

     def _handle_meta_command(self, command):
         meta_command = command['command']
@@ -436,6 +622,7 @@ Commands:
     DROP USER <user>
     CREATE USER <user> <password>
     ALTER USER PASSWORD <user> <new_password>
+    ALTER USER ACTIVE <user> <on/off>
     LIST DATASETS OF <user>
     LIST AGENTS OF <user>

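With the new handlers and the extra ALTER USER ACTIVE line in the help text, an interactive session might look like this (illustrative account names and values, assuming the Lark grammar takes quoted string arguments):

admin> CREATE USER 'bob@ragflow.io' 'S3cret!'
admin> ALTER USER ACTIVE 'bob@ragflow.io' off
admin> LIST DATASETS OF 'bob@ragflow.io'
admin> \q
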
@@ -460,10 +647,17 @@ def main():
     /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
     """)
     if cli.verify_admin(sys.argv):
-        cli.run_interactive()
+        cli.cmdloop()
     else:
+        print(r"""
+    ____ ___ ______________ ___ __ _
+    / __ \/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___
+    / /_/ / /| |/ / __/ /_ / / __ \ | /| / / / /| |/ __ / __ `__ \/ / __ \
+    / _, _/ ___ / /_/ / __/ / / /_/ / |/ |/ / / ___ / /_/ / / / / / / / / / /
+    /_/ |_/_/ |_\____/_/ /_/\____/|__/|__/ /_/ |_\__,_/_/ /_/ /_/_/_/ /_/
+    """)
+        if cli.verify_admin(sys.argv):
+            cli.cmdloop()
     # cli.run_single_command(sys.argv[1:])

admin/client/pyproject.toml (new file, 24 lines)
@@ -0,0 +1,24 @@
+[project]
+name = "ragflow-cli"
+version = "0.21.0.dev5"
+description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring."
+authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
+license = { text = "Apache License, Version 2.0" }
+readme = "README.md"
+requires-python = ">=3.10,<3.13"
+dependencies = [
+    "requests>=2.30.0,<3.0.0",
+    "beartype>=0.18.5,<0.19.0",
+    "pycryptodomex>=3.10.0",
+    "lark>=1.1.0",
+]
+
+[dependency-groups]
+test = [
+    "pytest>=8.3.5",
+    "requests>=2.32.3",
+    "requests-toolbelt>=1.0.0",
+]
+
+[project.scripts]
+ragflow-cli = "admin_client:main"
admin/pyproject.toml (new file, 24 lines)
@@ -0,0 +1,24 @@
+[project]
+name = "ragflow-cli"
+version = "0.21.0.dev2"
+description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring."
+authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
+license = { text = "Apache License, Version 2.0" }
+readme = "README.md"
+requires-python = ">=3.10,<3.13"
+dependencies = [
+    "requests>=2.30.0,<3.0.0",
+    "beartype>=0.18.5,<0.19.0",
+    "pycryptodomex>=3.10.0",
+    "lark>=1.1.0",
+]
+
+[dependency-groups]
+test = [
+    "pytest>=8.3.5",
+    "requests>=2.32.3",
+    "requests-toolbelt>=1.0.0",
+]
+
+[project.scripts]
+ragflow-cli = "ragflow_cli.admin_client:main"
@@ -1,15 +0,0 @@
-from flask import jsonify
-
-def success_response(data=None, message="Success", code = 0):
-    return jsonify({
-        "code": code,
-        "message": message,
-        "data": data
-    }), 200
-
-def error_response(message="Error", code=-1, data=None):
-    return jsonify({
-        "code": code,
-        "message": message,
-        "data": data
-    }), 400
@@ -1,3 +1,18 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 import os
 import signal
@@ -10,7 +25,8 @@ from flask import Flask
 from routes import admin_bp
 from api.utils.log_utils import init_root_logger
 from api.constants import SERVICE_CONF
-from config import load_configurations, SERVICE_CONFIGS
+from api import settings
+from admin.server.config import load_configurations, SERVICE_CONFIGS

 stop_event = threading.Event()

@@ -26,7 +42,7 @@ if __name__ == '__main__':

     app = Flask(__name__)
     app.register_blueprint(admin_bp)
+    settings.init_settings()
     SERVICE_CONFIGS.configs = load_configurations(SERVICE_CONF)

     try:
@@ -1,9 +1,26 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 import logging
 import uuid
 from functools import wraps
 from flask import request, jsonify

-from exceptions import AdminException
+from api.common.exceptions import AdminException
 from api.db.init_data import encode_to_base64
 from api.db.services import UserService

@@ -1,10 +1,27 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 import logging
 import threading
 from enum import Enum

 from pydantic import BaseModel
 from typing import Any
-from api.utils import read_config
+from api.utils.configs import read_config
 from urllib.parse import urlparse


@@ -32,9 +49,11 @@ class BaseConfig(BaseModel):
     host: str
     port: int
     service_type: str
+    detail_func_name: str

     def to_dict(self) -> dict[str, Any]:
-        return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, 'service_type': self.service_type}
+        return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
+                'service_type': self.service_type}


 class MetaConfig(BaseConfig):
@@ -209,7 +228,8 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
                 name: str = f'ragflow_{ragflow_count}'
                 host: str = v['host']
                 http_port: int = v['http_port']
-                config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port, service_type="ragflow_server")
+                config = RAGFlowServerConfig(id=id_count, name=name, host=host, port=http_port,
+                                             service_type="ragflow_server", detail_func_name="check_ragflow_server_alive")
                 configurations.append(config)
                 id_count += 1
             case "es":
@@ -222,7 +242,8 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
                 password: str = v.get('password')
                 config = ElasticsearchConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval",
                                              retrieval_type="elasticsearch",
-                                             username=username, password=password)
+                                             username=username, password=password,
+                                             detail_func_name="get_es_cluster_stats")
                 configurations.append(config)
                 id_count += 1

@@ -234,7 +255,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
                 port = int(parts[1])
                 database: str = v.get('db_name', 'default_db')
                 config = InfinityConfig(id=id_count, name=name, host=host, port=port, service_type="retrieval", retrieval_type="infinity",
-                                        db_name=database)
+                                        db_name=database, detail_func_name="get_infinity_status")
                 configurations.append(config)
                 id_count += 1
             case "minio":
@@ -246,7 +267,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
                 user = v.get('user')
                 password = v.get('password')
                 config = MinioConfig(id=id_count, name=name, host=host, port=port, user=user, password=password, service_type="file_store",
-                                     store_type="minio")
+                                     store_type="minio", detail_func_name="check_minio_alive")
                 configurations.append(config)
                 id_count += 1
             case "redis":
@@ -258,7 +279,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
                 password = v.get('password')
                 db: int = v.get('db')
                 config = RedisConfig(id=id_count, name=name, host=host, port=port, password=password, database=db,
-                                     service_type="message_queue", mq_type="redis")
+                                     service_type="message_queue", mq_type="redis", detail_func_name="get_redis_info")
                 configurations.append(config)
                 id_count += 1
             case "mysql":
@@ -268,7 +289,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
                 username = v.get('user')
                 password = v.get('password')
                 config = MySQLConfig(id=id_count, name=name, host=host, port=port, username=username, password=password,
-                                     service_type="meta_data", meta_type="mysql")
+                                     service_type="meta_data", meta_type="mysql", detail_func_name="get_mysql_status")
                 configurations.append(config)
                 id_count += 1
             case "admin":
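detail_func_name plumbs each service entry to a health probe: services.py (further below) resolves the string against api.utils.health_utils with getattr and calls it. The dispatch pattern in miniature, assuming health_utils exposes the named functions and each returns a dict carrying an 'alive' flag:

from api.utils import health_utils

def probe(config) -> dict:
    # e.g. config.detail_func_name == "get_redis_info"
    detail_func = getattr(health_utils, config.detail_func_name)
    return detail_func()  # expected to include an 'alive' flag
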
admin/server/models.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
admin/server/responses.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from flask import jsonify
+
+
+def success_response(data=None, message="Success", code=0):
+    return jsonify({
+        "code": code,
+        "message": message,
+        "data": data
+    }), 200
+
+
+def error_response(message="Error", code=-1, data=None):
+    return jsonify({
+        "code": code,
+        "message": message,
+        "data": data
+    }), 400
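Every endpoint funnels through these two helpers, so the CLI can rely on one envelope shape — the res_json['code'] / ['message'] / ['data'] fields it prints. A success body, with illustrative data:

# What success_response(...) serializes (HTTP 200):
{
    "code": 0,
    "message": "Success",
    "data": [{"email": "bob@ragflow.io", "is_active": "1"}],  # illustrative
}
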
@@ -1,8 +1,26 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 from flask import Blueprint, request
-from auth import login_verify
+
+from admin.server.auth import login_verify
 from responses import success_response, error_response
-from services import UserMgr, ServiceMgr
-from exceptions import AdminException
+from services import UserMgr, ServiceMgr, UserServiceMgr
+from api.common.exceptions import AdminException

 admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')

@@ -38,21 +56,29 @@ def create_user():
         password = data['password']
         role = data.get('role', 'user')

-        user = UserMgr.create_user(username, password, role)
-        return success_response(user, "User created successfully", 201)
+        res = UserMgr.create_user(username, password, role)
+        if res["success"]:
+            user_info = res["user_info"]
+            user_info.pop("password")  # do not return password
+            return success_response(user_info, "User created successfully")
+        else:
+            return error_response("create user failed")

     except AdminException as e:
         return error_response(e.message, e.code)
     except Exception as e:
-        return error_response(str(e), 500)
+        return error_response(str(e))


 @admin_bp.route('/users/<username>', methods=['DELETE'])
 @login_verify
 def delete_user(username):
     try:
-        UserMgr.delete_user(username)
-        return success_response(None, "User and all data deleted successfully")
+        res = UserMgr.delete_user(username)
+        if res["success"]:
+            return success_response(None, res["message"])
+        else:
+            return error_response(res["message"])

     except AdminException as e:
         return error_response(e.message, e.code)
@@ -69,8 +95,8 @@ def change_password(username):
             return error_response("New password is required", 400)

         new_password = data['new_password']
-        UserMgr.update_user_password(username, new_password)
-        return success_response(None, "Password updated successfully")
+        msg = UserMgr.update_user_password(username, new_password)
+        return success_response(None, msg)

     except AdminException as e:
         return error_response(e.message, e.code)
@@ -78,6 +104,22 @@ def change_password(username):
         return error_response(str(e), 500)


+@admin_bp.route('/users/<username>/activate', methods=['PUT'])
+@login_verify
+def alter_user_activate_status(username):
+    try:
+        data = request.get_json()
+        if not data or 'activate_status' not in data:
+            return error_response("Activation status is required", 400)
+        activate_status = data['activate_status']
+        msg = UserMgr.update_user_activate_status(username, activate_status)
+        return success_response(None, msg)
+    except AdminException as e:
+        return error_response(e.message, e.code)
+    except Exception as e:
+        return error_response(str(e), 500)
+
+
 @admin_bp.route('/users/<username>', methods=['GET'])
 @login_verify
 def get_user_details(username):
@@ -91,6 +133,32 @@ def get_user_details(username):
         return error_response(str(e), 500)


+@admin_bp.route('/users/<username>/datasets', methods=['GET'])
+@login_verify
+def get_user_datasets(username):
+    try:
+        datasets_list = UserServiceMgr.get_user_datasets(username)
+        return success_response(datasets_list)
+
+    except AdminException as e:
+        return error_response(e.message, e.code)
+    except Exception as e:
+        return error_response(str(e), 500)
+
+
+@admin_bp.route('/users/<username>/agents', methods=['GET'])
+@login_verify
+def get_user_agents(username):
+    try:
+        agents_list = UserServiceMgr.get_user_agents(username)
+        return success_response(agents_list)
+
+    except AdminException as e:
+        return error_response(e.message, e.code)
+    except Exception as e:
+        return error_response(str(e), 500)
+
+
 @admin_bp.route('/services', methods=['GET'])
 @login_verify
 def get_services():
admin/server/services.py (new file, 222 lines)
@@ -0,0 +1,222 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import re
+from werkzeug.security import check_password_hash
+from api.db import ActiveEnum
+from api.db.services import UserService
+from api.db.joint_services.user_account_service import create_new_user, delete_user_data
+from api.db.services.canvas_service import UserCanvasService
+from api.db.services.user_service import TenantService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.utils.crypt import decrypt
+from api.utils import health_utils
+
+from api.common.exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
+from admin.server.config import SERVICE_CONFIGS
+
+
+class UserMgr:
+    @staticmethod
+    def get_all_users():
+        users = UserService.get_all_users()
+        result = []
+        for user in users:
+            result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date,
+                           'is_active': user.is_active})
+        return result
+
+    @staticmethod
+    def get_user_details(username):
+        # use email to query
+        users = UserService.query_user_by_email(username)
+        result = []
+        for user in users:
+            result.append({
+                'email': user.email,
+                'language': user.language,
+                'last_login_time': user.last_login_time,
+                'is_authenticated': user.is_authenticated,
+                'is_active': user.is_active,
+                'is_anonymous': user.is_anonymous,
+                'login_channel': user.login_channel,
+                'status': user.status,
+                'is_superuser': user.is_superuser,
+                'create_date': user.create_date,
+                'update_date': user.update_date
+            })
+        return result
+
+    @staticmethod
+    def create_user(username, password, role="user") -> dict:
+        # Validate the email address
+        if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", username):
+            raise AdminException(f"Invalid email address: {username}!")
+        # Check if the email address is already used
+        if UserService.query(email=username):
+            raise UserAlreadyExistsError(username)
+        # Construct user info data
+        user_info_dict = {
+            "email": username,
+            "nickname": "",  # ask user to edit it manually in settings.
+            "password": decrypt(password),
+            "login_channel": "password",
+            "is_superuser": role == "admin",
+        }
+        return create_new_user(user_info_dict)
+
+    @staticmethod
+    def delete_user(username):
+        # use email to delete
+        user_list = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        if len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+        usr = user_list[0]
+        return delete_user_data(usr.id)
+
+    @staticmethod
+    def update_user_password(username, new_password) -> str:
+        # use email to find user. check exist and unique.
+        user_list = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+        # check new_password different from old.
+        usr = user_list[0]
+        psw = decrypt(new_password)
+        if check_password_hash(usr.password, psw):
+            return "Same password, no need to update!"
+        # update password
+        UserService.update_user_password(usr.id, psw)
+        return "Password updated successfully!"
+
+    @staticmethod
+    def update_user_activate_status(username, activate_status: str):
+        # use email to find user. check exist and unique.
+        user_list = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+        # check activate status different from new
+        usr = user_list[0]
+        # format activate_status before handle
+        _activate_status = activate_status.lower()
+        target_status = {
+            'on': ActiveEnum.ACTIVE.value,
+            'off': ActiveEnum.INACTIVE.value,
+        }.get(_activate_status)
+        if not target_status:
+            raise AdminException(f"Invalid activate_status: {activate_status}")
+        if target_status == usr.is_active:
+            return f"User activate status is already {_activate_status}!"
+        # update is_active
+        UserService.update_user(usr.id, {"is_active": target_status})
+        return f"Turn {_activate_status} user activate status successfully!"
+
+
+class UserServiceMgr:
+
+    @staticmethod
+    def get_user_datasets(username):
+        # use email to find user.
+        user_list = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+        # find tenants
+        usr = user_list[0]
+        tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
+        tenant_ids = [m["tenant_id"] for m in tenants]
+        # filter permitted kb and owned kb
+        return KnowledgebaseService.get_all_kb_by_tenant_ids(tenant_ids, usr.id)
+
+    @staticmethod
+    def get_user_agents(username):
+        # use email to find user.
+        user_list = UserService.query_user_by_email(username)
+        if not user_list:
+            raise UserNotFoundError(username)
+        elif len(user_list) > 1:
+            raise AdminException(f"Exist more than 1 user: {username}!")
+        # find tenants
+        usr = user_list[0]
+        tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
+        tenant_ids = [m["tenant_id"] for m in tenants]
+        # filter permitted agents and owned agents
+        res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
+        return [{
+            'title': r['title'],
+            'permission': r['permission'],
+            'canvas_type': r['canvas_type'],
+            'canvas_category': r['canvas_category']
+        } for r in res]
+
+
+class ServiceMgr:
+
+    @staticmethod
+    def get_all_services():
+        result = []
+        configs = SERVICE_CONFIGS.configs
+        for service_id, config in enumerate(configs):
+            config_dict = config.to_dict()
+            try:
+                service_detail = ServiceMgr.get_service_details(service_id)
+                if service_detail['alive']:
+                    config_dict['status'] = 'Alive'
+                else:
+                    config_dict['status'] = 'Timeout'
+            except Exception:
+                config_dict['status'] = 'Timeout'
+            result.append(config_dict)
+        return result
+
+    @staticmethod
+    def get_services_by_type(service_type_str: str):
+        raise AdminException("get_services_by_type: not implemented")
+
+    @staticmethod
+    def get_service_details(service_id: int):
+        service_id = int(service_id)
+        configs = SERVICE_CONFIGS.configs
+        service_config_mapping = {
+            c.id: {
+                'name': c.name,
+                'detail_func_name': c.detail_func_name
+            } for c in configs
+        }
+        service_info = service_config_mapping.get(service_id, {})
+        if not service_info:
+            raise AdminException(f"Invalid service_id: {service_id}")
+
+        detail_func = getattr(health_utils, service_info.get('detail_func_name'))
+        res = detail_func()
+        res.update({'service_name': service_info.get('name')})
+        return res
+
+    @staticmethod
+    def shutdown_service(service_id: int):
+        raise AdminException("shutdown_service: not implemented")
+
+    @staticmethod
+    def restart_service(service_id: int):
+        raise AdminException("restart_service: not implemented")
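Note how update_user_activate_status reduces the user-facing on/off flag to the stored enum value and short-circuits no-op updates. Isolated, and assuming (not shown in this diff) that ActiveEnum.ACTIVE.value == "1" and ActiveEnum.INACTIVE.value == "0":

status_map = {"on": "1", "off": "0"}         # stand-in for the ActiveEnum values
assert status_map.get("OFF".lower()) == "0"  # case-insensitive match
assert status_map.get("maybe") is None       # falls through to AdminException
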
@@ -1,54 +0,0 @@
-from api.db.services import UserService
-from exceptions import AdminException
-from config import SERVICE_CONFIGS
-
-
-class UserMgr:
-    @staticmethod
-    def get_all_users():
-        users = UserService.get_all_users()
-        result = []
-        for user in users:
-            result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date, 'is_active': user.is_active})
-        return result
-
-    @staticmethod
-    def get_user_details(username):
-        raise AdminException("get_user_details: not implemented")
-
-    @staticmethod
-    def create_user(username, password, role="user"):
-        raise AdminException("create_user: not implemented")
-
-    @staticmethod
-    def delete_user(username):
-        raise AdminException("delete_user: not implemented")
-
-    @staticmethod
-    def update_user_password(username, new_password):
-        raise AdminException("update_user_password: not implemented")
-
-
-class ServiceMgr:
-
-    @staticmethod
-    def get_all_services():
-        result = []
-        configs = SERVICE_CONFIGS.configs
-        for config in configs:
-            result.append(config.to_dict())
-        return result
-
-    @staticmethod
-    def get_services_by_type(service_type_str: str):
-        raise AdminException("get_services_by_type: not implemented")
-
-    @staticmethod
-    def get_service_details(service_id: int):
-        raise AdminException("get_service_details: not implemented")
-
-    @staticmethod
-    def shutdown_service(service_id: int):
-        raise AdminException("shutdown_service: not implemented")
-
-    @staticmethod
-    def restart_service(service_id: int):
-        raise AdminException("restart_service: not implemented")
@@ -27,7 +27,7 @@ from agent.component import component_class
 from agent.component.base import ComponentBase
 from api.db.services.file_service import FileService
 from api.utils import get_uuid, hash_str2int
-from rag.prompts.prompts import chunks_format
+from rag.prompts.generator import chunks_format
 from rag.utils.redis_conn import REDIS_CONN

 class Graph:
@@ -153,6 +153,16 @@ class Graph:
     def get_tenant_id(self):
         return self._tenant_id

+    def get_variable_value(self, exp: str) -> Any:
+        exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
+        if exp.find("@") < 0:
+            return self.globals[exp]
+        cpn_id, var_nm = exp.split("@")
+        cpn = self.get_component(cpn_id)
+        if not cpn:
+            raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
+        return cpn["obj"].output(var_nm)
+

 class Canvas(Graph):

@@ -193,7 +203,6 @@ class Canvas(Graph):
         self.history = []
         self.retrieval = []
-        self.memory = []

         for k in self.globals.keys():
             if isinstance(self.globals[k], str):
                 self.globals[k] = ""
@@ -282,7 +291,6 @@ class Canvas(Graph):
                     "thoughts": self.get_component_thoughts(self.path[i])
                 })
             _run_batch(idx, to)
-
             # post processing of components invocation
             for i in range(idx, to):
                 cpn = self.get_component(self.path[i])
@@ -383,7 +391,6 @@ class Canvas(Graph):
             self.path = path
             yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
             return
-
         self.path = self.path[:idx]
         if not self.error:
             yield decorate("workflow_finished",
@@ -406,16 +413,6 @@ class Canvas(Graph):
                 return False
         return True

-    def get_variable_value(self, exp: str) -> Any:
-        exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
-        if exp.find("@") < 0:
-            return self.globals[exp]
-        cpn_id, var_nm = exp.split("@")
-        cpn = self.get_component(cpn_id)
-        if not cpn:
-            raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
-        return cpn["obj"].output(var_nm)
-
     def get_history(self, window_size):
         convs = []
         if window_size <= 0:
@@ -490,7 +487,8 @@ class Canvas(Graph):

         r = self.retrieval[-1]
         for ck in chunks_format({"chunks": chunks}):
-            cid = hash_str2int(ck["id"], 100)
+            cid = hash_str2int(ck["id"], 500)
+            # cid = uuid.uuid5(uuid.NAMESPACE_DNS, ck["id"])
             if cid not in r:
                 r["chunks"][cid] = ck

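Moving get_variable_value from Canvas up to Graph lets any graph resolve the two reference shapes the agent runtime uses; braces are stripped first, then the presence of '@' picks the scope. With hypothetical component and variable names:

# No '@': a global, looked up in self.globals.
canvas.get_variable_value("{sys.query}")
# 'component_id@variable_name': routed to that component's output().
canvas.get_variable_value("{retrieval_0@chunks}")
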
@@ -28,9 +28,8 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from api.utils.api_utils import timeout
-from rag.prompts import message_fit_in
-from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
-    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
+from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
+    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM

@@ -138,7 +137,7 @@ class Agent(LLM, ToolBase):
         res.update(cpn.get_input_form())
         return res

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
     def _invoke(self, **kwargs):
         if kwargs.get("user_prompt"):
             usr_pmt = ""
@@ -347,3 +346,11 @@ Respond immediately with your final comprehensive answer.

         return "Error occurred."

+    def reset(self, temp=False):
+        """
+        Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
+        """
+        for k, cpn in self.tools.items():
+            if hasattr(cpn, "reset") and callable(cpn.reset):
+                cpn.reset()

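The new Agent.reset relies on a duck-typing guard: only tools that actually expose a callable reset are touched, so wrappers without one (MCPToolCallSession, per the docstring) no longer raise. The guard on its own:

class StatelessTool:          # no reset(): silently skipped
    pass

class StatefulTool:
    def reset(self):
        self.cache = {}

for tool in (StatelessTool(), StatefulTool()):
    if hasattr(tool, "reset") and callable(tool.reset):
        tool.reset()          # only StatefulTool reaches here
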
@@ -244,7 +244,7 @@ class ComponentParamBase(ABC):

         if not value_legal:
             raise ValueError(
-                "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
+                "Please check runtime conf, {} = {} does not match user-parameter restriction".format(
                     variable, value
                 )
             )
@@ -431,7 +431,7 @@ class ComponentBase(ABC):
         self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         raise NotImplementedError()

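The int() wrappers added here and in the agent, categorize, and invoke components fix a quiet type bug: os.environ.get returns a string whenever the variable is set, so @timeout could receive "30" instead of 30 while the integer default masked the problem locally. Concretely:

import os

os.environ["COMPONENT_EXEC_TIMEOUT"] = "30"
raw = os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)         # "30": a str
fixed = int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))  # 30: an int
assert isinstance(raw, str) and fixed == 30
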
@@ -28,7 +28,7 @@ from rag.llm.chat_model import ERROR_PREFIX
 class CategorizeParam(LLMParam):

     """
-    Define the Categorize component parameters.
+    Define the categorize component parameters.
     """
     def __init__(self):
         super().__init__()
@@ -80,7 +80,7 @@ Here's description of each category:
     - Prioritize the most specific applicable category
     - Return only the category name without explanations
     - Use "Other" only when no other category fits

 """.format(
     "\n - ".join(list(self.category_description.keys())),
     "\n".join(descriptions)
@@ -96,7 +96,7 @@ Here's description of each category:
 class Categorize(LLM, ABC):
     component_name = "Categorize"

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         msg = self._canvas.get_history(self._param.message_history_window_size)
         if not msg:
@@ -112,7 +112,7 @@ class Categorize(LLM, ABC):

         user_prompt = """
 ---- Real Data ----
 {} →
 """.format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
         ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
         logging.info(f"input: {user_prompt}, answer: {str(ans)}")
@@ -134,4 +134,4 @@ class Categorize(LLM, ABC):
         self.set_output("_next", cpn_ids)

     def thoughts(self) -> str:
         return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))

@ -19,11 +19,12 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from deepdoc.parser import HtmlParser
|
from deepdoc.parser import HtmlParser
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeParam(ComponentParamBase):
|
class InvokeParam(ComponentParamBase):
|
||||||
@@ -43,17 +44,17 @@ class InvokeParam(ComponentParamBase):
         self.datatype = "json"  # New parameter to determine data posting type
 
     def check(self):
-        self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
+        self.check_valid_value(self.method.lower(), "Type of content from the crawler", ["get", "post", "put"])
         self.check_empty(self.url, "End point URL")
         self.check_positive_integer(self.timeout, "Timeout time in second")
         self.check_boolean(self.clean_html, "Clean HTML")
-        self.check_valid_value(self.datatype.lower(), "Data post type", ['json', 'formdata'])  # Check for valid datapost value
+        self.check_valid_value(self.datatype.lower(), "Data post type", ["json", "formdata"])  # Check for valid datapost value
 
 
 class Invoke(ComponentBase, ABC):
     component_name = "Invoke"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
     def _invoke(self, **kwargs):
         args = {}
         for para in self._param.variables:
@@ -63,6 +64,18 @@ class Invoke(ComponentBase, ABC):
             args[para["key"]] = self._canvas.get_variable_value(para["ref"])
 
         url = self._param.url.strip()
 
+        def replace_variable(match):
+            var_name = match.group(1)
+            try:
+                value = self._canvas.get_variable_value(var_name)
+                return str(value or "")
+            except Exception:
+                return ""
+
+        # {base_url} or {component_id@variable_name}
+        url = re.sub(r"\{([a-zA-Z_][a-zA-Z0-9_.@-]*)\}", replace_variable, url)
+
         if url.find("http") != 0:
             url = "http://" + url
 
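The hunk above lets an Invoke endpoint URL carry `{...}` placeholders resolved from canvas variables before the request is sent. A self-contained sketch of the same `re.sub`-with-callback technique, with a plain dict standing in for the canvas variable store (the dict and its keys are illustrative, not RAGFlow API):

```python
import re

# Stand-in for self._canvas.get_variable_value(); keys are made up.
variables = {"base_url": "api.example.com", "cpn_1@path": "v1/search"}

def replace_variable(match: re.Match) -> str:
    # Unknown or empty variables resolve to "", so a bad template
    # degrades to a (likely invalid) URL instead of raising.
    return str(variables.get(match.group(1)) or "")

url = "{base_url}/{cpn_1@path}?q=test"
print(re.sub(r"\{([a-zA-Z_][a-zA-Z0-9_.@-]*)\}", replace_variable, url))
# -> api.example.com/v1/search?q=test
```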
@@ -75,52 +88,32 @@ class Invoke(ComponentBase, ABC):
             proxies = {"http": self._param.proxy, "https": self._param.proxy}
 
         last_e = ""
-        for _ in range(self._param.max_retries+1):
+        for _ in range(self._param.max_retries + 1):
             try:
-                if method == 'get':
-                    response = requests.get(url=url,
-                                            params=args,
-                                            headers=headers,
-                                            proxies=proxies,
-                                            timeout=self._param.timeout)
+                if method == "get":
+                    response = requests.get(url=url, params=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
                     if self._param.clean_html:
                         sections = HtmlParser()(None, response.content)
                         self.set_output("result", "\n".join(sections))
                     else:
                         self.set_output("result", response.text)
 
-                if method == 'put':
-                    if self._param.datatype.lower() == 'json':
-                        response = requests.put(url=url,
-                                                json=args,
-                                                headers=headers,
-                                                proxies=proxies,
-                                                timeout=self._param.timeout)
+                if method == "put":
+                    if self._param.datatype.lower() == "json":
+                        response = requests.put(url=url, json=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
                     else:
-                        response = requests.put(url=url,
-                                                data=args,
-                                                headers=headers,
-                                                proxies=proxies,
-                                                timeout=self._param.timeout)
+                        response = requests.put(url=url, data=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
                     if self._param.clean_html:
                         sections = HtmlParser()(None, response.content)
                         self.set_output("result", "\n".join(sections))
                     else:
                         self.set_output("result", response.text)
 
-                if method == 'post':
-                    if self._param.datatype.lower() == 'json':
-                        response = requests.post(url=url,
-                                                 json=args,
-                                                 headers=headers,
-                                                 proxies=proxies,
-                                                 timeout=self._param.timeout)
+                if method == "post":
+                    if self._param.datatype.lower() == "json":
+                        response = requests.post(url=url, json=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
                     else:
-                        response = requests.post(url=url,
-                                                 data=args,
-                                                 headers=headers,
-                                                 proxies=proxies,
-                                                 timeout=self._param.timeout)
+                        response = requests.post(url=url, data=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
                     if self._param.clean_html:
                         self.set_output("result", "\n".join(sections))
                     else:
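Collapsing each multi-line `requests` call onto one line makes the retry skeleton easier to see: one initial attempt plus `max_retries` more, with the last error kept for reporting. A condensed, hedged sketch of that shape (function name illustrative):

```python
import requests

def fetch_with_retries(url: str, max_retries: int = 2, timeout: int = 3) -> str:
    last_e = ""
    # range(max_retries + 1): the first try plus max_retries retries.
    for _ in range(max_retries + 1):
        try:
            return requests.get(url=url, timeout=timeout).text
        except Exception as e:
            last_e = str(e)
    raise RuntimeError(f"all {max_retries + 1} attempts failed: {last_e}")
```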
@@ -26,8 +26,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
-from rag.prompts import message_fit_in, citation_prompt
-from rag.prompts.prompts import tool_call_summary
+from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt
 
 
 class LLMParam(ComponentParamBase):
@@ -82,9 +81,9 @@ class LLMParam(ComponentParamBase):
 
 class LLM(ComponentBase):
     component_name = "LLM"
 
-    def __init__(self, canvas, id, param: ComponentParamBase):
-        super().__init__(canvas, id, param)
+    def __init__(self, canvas, component_id, param: ComponentParamBase):
+        super().__init__(canvas, component_id, param)
         self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
                                   self._param.llm_id, max_retries=self._param.max_retries,
                                   retry_interval=self._param.delay_after_error
@@ -102,6 +101,8 @@ class LLM(ComponentBase):
 
     def get_input_elements(self) -> dict[str, Any]:
         res = self.get_input_elements_from_text(self._param.sys_prompt)
+        if isinstance(self._param.prompts, str):
+            self._param.prompts = [{"role": "user", "content": self._param.prompts}]
         for prompt in self._param.prompts:
             d = self.get_input_elements_from_text(prompt["content"])
             res.update(d)
@@ -113,6 +114,17 @@ class LLM(ComponentBase):
     def add2system_prompt(self, txt):
         self._param.sys_prompt += txt
 
+    def _sys_prompt_and_msg(self, msg, args):
+        if isinstance(self._param.prompts, str):
+            self._param.prompts = [{"role": "user", "content": self._param.prompts}]
+        for p in self._param.prompts:
+            if msg and msg[-1]["role"] == p["role"]:
+                continue
+            p = deepcopy(p)
+            p["content"] = self.string_format(p["content"], args)
+            msg.append(p)
+        return msg, self.string_format(self._param.sys_prompt, args)
+
     def _prepare_prompt_variables(self):
         if self._param.visual_files_var:
             self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
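The new helper centralizes two rules `_prepare_prompt_variables` previously applied inline: a configured prompt is only appended when it would not repeat the role of the last history turn, and each appended prompt is deep-copied before its placeholders are filled, so the stored template is never mutated. A stand-alone sketch of the same logic, with `str.format` standing in for `string_format`:

```python
from copy import deepcopy

history = [{"role": "user", "content": "hi"}]
prompts = [{"role": "user", "content": "Context: {ctx}"},
           {"role": "assistant", "content": "Noted."}]
args = {"ctx": "docs"}

for p in prompts:
    if history and history[-1]["role"] == p["role"]:
        continue                      # skip: would repeat the last role
    p = deepcopy(p)                   # keep the configured template pristine
    p["content"] = p["content"].format(**args)
    history.append(p)

print(history)
# the same-role user prompt is skipped; only the assistant turn is appended
```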
@@ -128,7 +140,6 @@ class LLM(ComponentBase):
 
         args = {}
         vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
-        sys_prompt = self._param.sys_prompt
         for k, o in vars.items():
             args[k] = o["value"]
             if not isinstance(args[k], str):
@@ -138,16 +149,8 @@ class LLM(ComponentBase):
                 args[k] = str(args[k])
             self.set_input_value(k, args[k])
 
-        msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
-        for p in self._param.prompts:
-            if msg and msg[-1]["role"] == p["role"]:
-                continue
-            msg.append(deepcopy(p))
-
-        sys_prompt = self.string_format(sys_prompt, args)
+        msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
         user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
-        for m in msg:
-            m["content"] = self.string_format(m["content"], args)
         if self._param.cite and self._canvas.get_reference()["chunks"]:
             sys_prompt += citation_prompt(user_defined_prompt)
 
@@ -202,7 +205,7 @@ class LLM(ComponentBase):
             for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
                 yield delta(txt)
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         def clean_formated_answer(ans: str) -> str:
             ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@@ -210,7 +213,7 @@ class LLM(ComponentBase):
             return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
 
         prompt, msg, _ = self._prepare_prompt_variables()
-        error = ""
+        error: str = ""
 
         if self._param.output_structure:
             prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
@@ -49,7 +49,7 @@ class MessageParam(ComponentParamBase):
 class Message(ComponentBase):
     component_name = "Message"
 
-    def get_kwargs(self, script:str, kwargs:dict = {}, delimeter:str=None) -> tuple[str, dict[str, str | list | Any]]:
+    def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
         for k,v in self.get_input_elements_from_text(script).items():
             if k in kwargs:
                 continue
@@ -60,8 +60,8 @@ class Message(ComponentBase):
             if isinstance(v, partial):
                 for t in v():
                     ans += t
-            elif isinstance(v, list) and delimeter:
-                ans = delimeter.join([str(vv) for vv in v])
+            elif isinstance(v, list) and delimiter:
+                ans = delimiter.join([str(vv) for vv in v])
             elif not isinstance(v, str):
                 try:
                     ans = json.dumps(v, ensure_ascii=False)
@@ -127,7 +127,7 @@ class Message(ComponentBase):
         ]
         return any([re.search(p, content) for p in patt])
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         rand_cnt = random.choice(self._param.content)
         if self._param.stream and not self._is_jinjia2(rand_cnt):
@@ -56,7 +56,7 @@ class StringTransform(Message, ABC):
             "type": "line"
         } for k, o in self.get_input_elements_from_text(self._param.script).items()}
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         if self._param.method == "split":
             self._split(kwargs.get("line"))
@@ -90,7 +90,7 @@ class StringTransform(Message, ABC):
         for k,v in kwargs.items():
             if not v:
                 v = ""
-            script = re.sub(k, v, script)
+            script = re.sub(k, lambda match: v, script)
 
         self.set_output("result", script)
 
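The `lambda match: v` change above is worth pausing on: when `re.sub` is given a string replacement, backslashes and `\1`-style group references inside it are interpreted, so substituting arbitrary user text can raise `re.error` or corrupt the output. Wrapping the value in a callable inserts it verbatim. A minimal sketch:

```python
import re

v = r"C:\Users\new"   # replacement text containing backslashes

# String replacement: re.sub parses "\U" as an escape and fails.
try:
    re.sub("name", v, "hello name")
except re.error as e:
    print("string replacement failed:", e)

# Callable replacement: the value is inserted literally, no escape parsing.
print(re.sub("name", lambda match: v, "hello name"))
# -> hello C:\Users\new
```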
@@ -61,7 +61,7 @@ class SwitchParam(ComponentParamBase):
 class Switch(ComponentBase, ABC):
     component_name = "Switch"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
     def _invoke(self, **kwargs):
         for cond in self._param.conditions:
             res = []
 726   agent/templates/advanced_ingestion_pipeline.json   Normal file (diff suppressed because one or more lines are too long)
 493   agent/templates/chunk_summary.json                 Normal file (diff suppressed because one or more lines are too long)
 1172  agent/templates/stock_research_report.json         Normal file (diff suppressed because one or more lines are too long)
 369   agent/templates/title_chunker.json                 Normal file (diff suppressed because one or more lines are too long)
@@ -61,7 +61,7 @@ class ArXivParam(ToolParamBase):
 class ArXiv(ToolBase, ABC):
     component_name = "ArXiv"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -97,6 +97,6 @@ class ArXiv(ToolBase, ABC):
 
     def thoughts(self) -> str:
         return """
 Keywords: {}
 Looking for the most relevant articles.
 """.format(self.get_input().get("query", "-_-!"))
@@ -22,7 +22,7 @@ from typing import TypedDict, List, Any
 from agent.component.base import ComponentParamBase, ComponentBase
 from api.utils import hash_str2int
 from rag.llm.chat_model import ToolCallSession
-from rag.prompts.prompts import kb_prompt
+from rag.prompts.generator import kb_prompt
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession
 from timeit import default_timer as timer
 
@@ -129,7 +129,7 @@ module.exports = { main };
 class CodeExec(ToolBase, ABC):
     component_name = "CodeExec"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         lang = kwargs.get("lang", self._param.lang)
         script = kwargs.get("script", self._param.script)
@@ -156,7 +156,7 @@ class CodeExec(ToolBase, ABC):
         self.set_output("_ERROR", "construct code request error: " + str(e))
 
         try:
-            resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+            resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
             logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
             if resp.status_code != 200:
                 resp.raise_for_status()
@@ -73,7 +73,7 @@ class DuckDuckGoParam(ToolParamBase):
 class DuckDuckGo(ToolBase, ABC):
     component_name = "DuckDuckGo"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -115,6 +115,6 @@ class DuckDuckGo(ToolBase, ABC):
 
     def thoughts(self) -> str:
         return """
 Keywords: {}
 Looking for the most relevant articles.
 """.format(self.get_input().get("query", "-_-!"))
@@ -98,8 +98,8 @@ class EmailParam(ToolParamBase):
 
 class Email(ToolBase, ABC):
     component_name = "Email"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if not kwargs.get("to_email"):
             self.set_output("success", False)
@@ -212,4 +212,4 @@ class Email(ToolBase, ABC):
 To: {}
 Subject: {}
 Your email is on its way—sit tight!
 """.format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
@@ -53,12 +53,13 @@ class ExeSQLParam(ToolParamBase):
         self.max_records = 1024
 
     def check(self):
-        self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql'])
+        self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
         self.check_empty(self.database, "Database name")
         self.check_empty(self.username, "database username")
         self.check_empty(self.host, "IP Address")
         self.check_positive_integer(self.port, "IP Port")
-        self.check_empty(self.password, "Database password")
+        if self.db_type != "trino":
+            self.check_empty(self.password, "Database password")
         self.check_positive_integer(self.max_records, "Maximum number of records")
         if self.database == "rag_flow":
             if self.host == "ragflow-mysql":
@@ -78,7 +79,7 @@ class ExeSQLParam(ToolParamBase):
 class ExeSQL(ToolBase, ABC):
     component_name = "ExeSQL"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
 
         def convert_decimals(obj):
@@ -123,6 +124,94 @@ class ExeSQL(ToolBase, ABC):
                 r'PWD=' + self._param.password
             )
             db = pyodbc.connect(conn_str)
+        elif self._param.db_type == 'trino':
+            try:
+                import trino
+                from trino.auth import BasicAuthentication
+            except Exception:
+                raise Exception("Missing dependency 'trino'. Please install: pip install trino")
+
+            def _parse_catalog_schema(db: str):
+                if not db:
+                    return None, None
+                if "." in db:
+                    c, s = db.split(".", 1)
+                elif "/" in db:
+                    c, s = db.split("/", 1)
+                else:
+                    c, s = db, "default"
+                return c, s
+
+            catalog, schema = _parse_catalog_schema(self._param.database)
+            if not catalog:
+                raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")
+
+            http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
+            auth = None
+            if http_scheme == "https" and self._param.password:
+                auth = BasicAuthentication(self._param.username, self._param.password)
+
+            try:
+                db = trino.dbapi.connect(
+                    host=self._param.host,
+                    port=int(self._param.port or 8080),
+                    user=self._param.username or "ragflow",
+                    catalog=catalog,
+                    schema=schema or "default",
+                    http_scheme=http_scheme,
+                    auth=auth
+                )
+            except Exception as e:
+                raise Exception("Database Connection Failed! \n" + str(e))
+        elif self._param.db_type == 'IBM DB2':
+            import ibm_db
+            conn_str = (
+                f"DATABASE={self._param.database};"
+                f"HOSTNAME={self._param.host};"
+                f"PORT={self._param.port};"
+                f"PROTOCOL=TCPIP;"
+                f"UID={self._param.username};"
+                f"PWD={self._param.password};"
+            )
+            try:
+                conn = ibm_db.connect(conn_str, "", "")
+            except Exception as e:
+                raise Exception("Database Connection Failed! \n" + str(e))
+
+            sql_res = []
+            formalized_content = []
+            for single_sql in sqls:
+                single_sql = single_sql.replace("```", "").strip()
+                if not single_sql:
+                    continue
+                single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
+
+                stmt = ibm_db.exec_immediate(conn, single_sql)
+                rows = []
+                row = ibm_db.fetch_assoc(stmt)
+                while row and len(rows) < self._param.max_records:
+                    rows.append(row)
+                    row = ibm_db.fetch_assoc(stmt)
+
+                if not rows:
+                    sql_res.append({"content": "No record in the database!"})
+                    continue
+
+                df = pd.DataFrame(rows)
+                for col in df.columns:
+                    if pd.api.types.is_datetime64_any_dtype(df[col]):
+                        df[col] = df[col].dt.strftime("%Y-%m-%d")
+
+                df = df.where(pd.notnull(df), None)
+
+                sql_res.append(convert_decimals(df.to_dict(orient="records")))
+                formalized_content.append(df.to_markdown(index=False, floatfmt=".6f"))
+
+            ibm_db.close(conn)
+
+            self.set_output("json", sql_res)
+            self.set_output("formalized_content", "\n\n".join(formalized_content))
+            return self.output("formalized_content")
         try:
             cursor = db.cursor()
         except Exception as e:
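Because the ExeSQL form exposes a single `database` field while Trino addresses tables as `catalog.schema.table`, the new code folds both accepted spellings (dot- or slash-separated) into a catalog/schema pair, defaulting the schema. A quick check of the accepted inputs — helper copied from the diff, test values made up:

```python
def _parse_catalog_schema(db: str):
    if not db:
        return None, None
    if "." in db:
        c, s = db.split(".", 1)
    elif "/" in db:
        c, s = db.split("/", 1)
    else:
        c, s = db, "default"
    return c, s

for spec in ("hive.sales", "hive/sales", "hive", ""):
    print(repr(spec), "->", _parse_catalog_schema(spec))
# 'hive.sales' -> ('hive', 'sales')
# 'hive/sales' -> ('hive', 'sales')
# 'hive'      -> ('hive', 'default')
# ''          -> (None, None)
```

Note the design choice of attaching `BasicAuthentication` only when `TRINO_USE_TLS=1` selects HTTPS: Trino rejects password authentication over plain HTTP, which is also why the password check above is skipped for the trino db_type.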
@@ -150,6 +239,8 @@ class ExeSQL(ToolBase, ABC):
                 if pd.api.types.is_datetime64_any_dtype(single_res[col]):
                     single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
 
+            single_res = single_res.where(pd.notnull(single_res), None)
+
             sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
             formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
 
@@ -57,7 +57,7 @@ class GitHubParam(ToolParamBase):
 class GitHub(ToolBase, ABC):
     component_name = "GitHub"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -88,4 +88,4 @@ class GitHub(ToolBase, ABC):
         assert False, self.output()
 
     def thoughts(self) -> str:
         return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
@@ -116,7 +116,7 @@ class GoogleParam(ToolParamBase):
 class Google(ToolBase, ABC):
     component_name = "Google"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("q"):
             self.set_output("formalized_content", "")
@@ -154,6 +154,6 @@ class Google(ToolBase, ABC):
 
     def thoughts(self) -> str:
         return """
 Keywords: {}
 Looking for the most relevant articles.
 """.format(self.get_input().get("query", "-_-!"))
@@ -63,7 +63,7 @@ class GoogleScholarParam(ToolParamBase):
 class GoogleScholar(ToolBase, ABC):
     component_name = "GoogleScholar"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -93,4 +93,4 @@ class GoogleScholar(ToolBase, ABC):
         assert False, self.output()
 
     def thoughts(self) -> str:
         return "Looking for scholarly papers on `{}`, prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
@@ -33,7 +33,7 @@ class PubMedParam(ToolParamBase):
         self.meta:ToolMeta = {
             "name": "pubmed_search",
             "description": """
 PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
 In addition to MEDLINE, PubMed provides access to:
 - older references from the print version of Index Medicus, back to 1951 and earlier
 - references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
@@ -69,7 +69,7 @@ In addition to MEDLINE, PubMed provides access to:
 class PubMed(ToolBase, ABC):
     component_name = "PubMed"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -85,13 +85,7 @@ class PubMed(ToolBase, ABC):
             self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
                                   get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
                                   get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
-                                  get_content=lambda child: child.find("MedlineCitation") \
-                                      .find("Article") \
-                                      .find("Abstract") \
-                                      .find("AbstractText").text \
-                                  if child.find("MedlineCitation") \
-                                      .find("Article").find("Abstract") \
-                                  else "No abstract available")
+                                  get_content=lambda child: self._format_pubmed_content(child),)
             return self.output("formalized_content")
         except Exception as e:
             last_e = e
@@ -104,5 +98,50 @@ class PubMed(ToolBase, ABC):
 
         assert False, self.output()
 
+    def _format_pubmed_content(self, child):
+        """Extract structured reference info from PubMed XML"""
+        def safe_find(path):
+            node = child
+            for p in path.split("/"):
+                if node is None:
+                    return None
+                node = node.find(p)
+            return node.text if node is not None and node.text else None
+
+        title = safe_find("MedlineCitation/Article/ArticleTitle") or "No title"
+        abstract = safe_find("MedlineCitation/Article/Abstract/AbstractText") or "No abstract available"
+        journal = safe_find("MedlineCitation/Article/Journal/Title") or "Unknown Journal"
+        volume = safe_find("MedlineCitation/Article/Journal/JournalIssue/Volume") or "-"
+        issue = safe_find("MedlineCitation/Article/Journal/JournalIssue/Issue") or "-"
+        pages = safe_find("MedlineCitation/Article/Pagination/MedlinePgn") or "-"
+
+        # Authors
+        authors = []
+        for author in child.findall(".//AuthorList/Author"):
+            lastname = safe_find("LastName") or ""
+            forename = safe_find("ForeName") or ""
+            fullname = f"{forename} {lastname}".strip()
+            if fullname:
+                authors.append(fullname)
+        authors_str = ", ".join(authors) if authors else "Unknown Authors"
+
+        # DOI
+        doi = None
+        for eid in child.findall(".//ArticleId"):
+            if eid.attrib.get("IdType") == "doi":
+                doi = eid.text
+                break
+
+        return (
+            f"Title: {title}\n"
+            f"Authors: {authors_str}\n"
+            f"Journal: {journal}\n"
+            f"Volume: {volume}\n"
+            f"Issue: {issue}\n"
+            f"Pages: {pages}\n"
+            f"DOI: {doi or '-'}\n"
+            f"Abstract: {abstract.strip()}"
+        )
+
     def thoughts(self) -> str:
         return "Looking for scholarly papers on `{}`, prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
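`safe_find` walks an XPath-like path one tag at a time and bails out at the first missing level, which is what makes the structured extraction robust against PubMed records lacking an abstract, journal issue, and so on. A runnable sketch against a stripped-down record (XML made up for illustration); note that in the committed helper the walk always starts from the article root, so the per-author lookups in the loop are resolved from there as well:

```python
import xml.etree.ElementTree as ET

child = ET.fromstring(
    "<PubmedArticle><MedlineCitation><Article>"
    "<ArticleTitle>Example study</ArticleTitle>"
    "</Article></MedlineCitation></PubmedArticle>"
)

def safe_find(path):
    node = child
    for p in path.split("/"):
        if node is None:
            return None          # a missing level ends the walk early
        node = node.find(p)
    return node.text if node is not None and node.text else None

print(safe_find("MedlineCitation/Article/ArticleTitle"))           # Example study
print(safe_find("MedlineCitation/Article/Abstract/AbstractText"))  # None, no exception
```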
@@ -23,8 +23,7 @@ from api.db.services.llm_service import LLMBundle
 from api import settings
 from api.utils.api_utils import timeout
 from rag.app.tag import label_question
-from rag.prompts import kb_prompt
-from rag.prompts.prompts import cross_languages
+from rag.prompts.generator import cross_languages, kb_prompt
 
 
 class RetrievalParam(ToolParamBase):
@@ -58,6 +57,7 @@ class RetrievalParam(ToolParamBase):
         self.empty_response = ""
         self.use_kg = False
         self.cross_languages = []
+        self.toc_enhance = False
 
     def check(self):
         self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
@@ -75,7 +75,7 @@ class RetrievalParam(ToolParamBase):
 class Retrieval(ToolBase, ABC):
     component_name = "Retrieval"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", self._param.empty_response)
@@ -122,7 +122,7 @@ class Retrieval(ToolBase, ABC):
 
         if kbs:
             query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retrievaler.retrieval(
+            kbinfos = settings.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
@@ -135,8 +135,13 @@ class Retrieval(ToolBase, ABC):
                 rerank_mdl=rerank_mdl,
                 rank_feature=label_question(query, kbs),
             )
+            if self._param.toc_enhance:
+                chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
+                cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+                if cks:
+                    kbinfos["chunks"] = cks
             if self._param.use_kg:
-                ck = settings.kg_retrievaler.retrieval(query,
+                ck = settings.kg_retriever.retrieval(query,
                                                        [kb.tenant_id for kb in kbs],
                                                        kb_ids,
                                                        embd_mdl,
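Two details of the new TOC hunk are easy to miss: the table-of-contents pass is an optional post-step over the chunks the ordinary retrieval already produced, and its result only replaces those chunks when it is non-empty, so a failed or empty enhancement cannot lose results. The guard reduces to this pattern (names illustrative):

```python
# Keep the baseline result unless the optional enhancer actually
# produced something; an empty enhancement must not wipe out chunks.
def maybe_enhance(chunks, enhancer):
    enhanced = enhancer(chunks)
    return enhanced if enhanced else chunks

print(maybe_enhance(["c1", "c2"], lambda c: []))      # ['c1', 'c2']
print(maybe_enhance(["c1", "c2"], lambda c: ["t1"]))  # ['t1']
```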
@@ -147,7 +152,7 @@ class Retrieval(ToolBase, ABC):
             kbinfos = {"chunks": [], "doc_aggs": []}
 
         if self._param.use_kg and kbs:
-            ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
+            ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ck["content"] = ck["content_with_weight"]
                 del ck["content_with_weight"]
@@ -165,18 +170,18 @@ class Retrieval(ToolBase, ABC):
 
         # Format the chunks for JSON output (similar to how other tools do it)
         json_output = kbinfos["chunks"].copy()
 
         self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
         form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
 
         # Set both formalized content and JSON output
         self.set_output("formalized_content", form_cnt)
         self.set_output("json", json_output)
 
         return form_cnt
 
     def thoughts(self) -> str:
         return """
 Keywords: {}
 Looking for the most relevant articles.
 """.format(self.get_input().get("query", "-_-!"))
@@ -77,7 +77,7 @@ class SearXNGParam(ToolParamBase):
 class SearXNG(ToolBase, ABC):
     component_name = "SearXNG"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         # Gracefully handle try-run without inputs
         query = kwargs.get("query")
@@ -85,7 +85,7 @@ class SearXNG(ToolBase, ABC):
             self.set_output("formalized_content", "")
             return ""
 
-        searxng_url = (kwargs.get("searxng_url") or getattr(self._param, "searxng_url", "") or "").strip()
+        searxng_url = (getattr(self._param, "searxng_url", "") or kwargs.get("searxng_url") or "").strip()
         # In try-run, if no URL configured, just return empty instead of raising
         if not searxng_url:
             self.set_output("formalized_content", "")
@@ -94,7 +94,6 @@ class SearXNG(ToolBase, ABC):
         last_e = ""
         for _ in range(self._param.max_retries+1):
             try:
-                # Build the search parameters
                 search_params = {
                     'q': query,
                     'format': 'json',
@@ -104,33 +103,29 @@ class SearXNG(ToolBase, ABC):
                     'pageno': 1
                 }
 
-                # Send the search request
                 response = requests.get(
                     f"{searxng_url}/search",
                     params=search_params,
                     timeout=10
                 )
                 response.raise_for_status()
 
                 data = response.json()
 
-                # Validate the response data
                 if not data or not isinstance(data, dict):
                     raise ValueError("Invalid response from SearXNG")
 
                 results = data.get("results", [])
                 if not isinstance(results, list):
                     raise ValueError("Invalid results format from SearXNG")
 
-                # Limit the number of results
                 results = results[:self._param.top_n]
 
-                # Process the search results
                 self._retrieve_chunks(results,
                                       get_title=lambda r: r.get("title", ""),
                                       get_url=lambda r: r.get("url", ""),
                                       get_content=lambda r: r.get("content", ""))
 
                 self.set_output("json", results)
                 return self.output("formalized_content")
 
@@ -151,6 +146,6 @@ class SearXNG(ToolBase, ABC):
 
     def thoughts(self) -> str:
         return """
 Keywords: {}
 Searching with SearXNG for relevant results...
 """.format(self.get_input().get("query", "-_-!"))
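Besides dropping the (originally Chinese) step comments, the second hunk flips precedence: the component's configured `searxng_url` now wins over a per-call `kwargs` value. For reference, SearXNG's JSON API is just a GET against `/search` with `format=json` (the instance must have the JSON format enabled in its settings). A minimal stand-alone query in the same shape the component uses — the instance URL is a placeholder:

```python
import requests

searxng_url = "http://localhost:8888"  # placeholder: point at your own instance
params = {"q": "retrieval augmented generation", "format": "json", "pageno": 1}

data = requests.get(f"{searxng_url}/search", params=params, timeout=10).json()
for r in data.get("results", [])[:5]:
    print(r.get("title", ""), "->", r.get("url", ""))
```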
@@ -31,7 +31,7 @@ class TavilySearchParam(ToolParamBase):
         self.meta:ToolMeta = {
             "name": "tavily_search",
             "description": """
 Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
 When searching:
 - Start with specific query which should focus on just a single aspect.
 - Number of keywords in query should be less than 5.
@@ -101,7 +101,7 @@ When searching:
 class TavilySearch(ToolBase, ABC):
     component_name = "TavilySearch"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -136,7 +136,7 @@ class TavilySearch(ToolBase, ABC):
 
     def thoughts(self) -> str:
         return """
 Keywords: {}
 Looking for the most relevant articles.
 """.format(self.get_input().get("query", "-_-!"))
 
@@ -199,7 +199,7 @@ class TavilyExtractParam(ToolParamBase):
 class TavilyExtract(ToolBase, ABC):
     component_name = "TavilyExtract"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
     def _invoke(self, **kwargs):
         self.tavily_client = TavilyClient(api_key=self._param.api_key)
         last_e = None
@@ -224,4 +224,4 @@ class TavilyExtract(ToolBase, ABC):
         assert False, self.output()
 
     def thoughts(self) -> str:
         return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
@@ -68,7 +68,7 @@ fund selection platform: through AI technology, is committed to providing excell
 class WenCai(ToolBase, ABC):
     component_name = "WenCai"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("report", "")
@@ -111,4 +111,4 @@ class WenCai(ToolBase, ABC):
         assert False, self.output()
 
     def thoughts(self) -> str:
         return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
@@ -64,7 +64,7 @@ class WikipediaParam(ToolParamBase):
 class Wikipedia(ToolBase, ABC):
     component_name = "Wikipedia"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if not kwargs.get("query"):
             self.set_output("formalized_content", "")
@@ -99,6 +99,6 @@ class Wikipedia(ToolBase, ABC):
 
     def thoughts(self) -> str:
         return """
 Keywords: {}
 Looking for the most relevant articles.
 """.format(self.get_input().get("query", "-_-!"))
@@ -72,7 +72,7 @@ class YahooFinanceParam(ToolParamBase):
 class YahooFinance(ToolBase, ABC):
     component_name = "YahooFinance"
 
-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if not kwargs.get("stock_code"):
             self.set_output("report", "")
@@ -111,4 +111,4 @@ class YahooFinance(ToolBase, ABC):
         assert False, self.output()
 
     def thoughts(self) -> str:
         return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
@@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from api.db import StatusEnum
 from api.db.db_models import close_connection
 from api.db.services import UserService
-from api.utils import CustomJSONEncoder, commands
+from api.utils.json import CustomJSONEncoder
+from api.utils import commands
 
 from flask_mail import Mail
 from flask_session import Session
@@ -39,7 +39,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
 
 from api.utils.file_utils import filename_type, thumbnail
 from rag.app.tag import label_question
-from rag.prompts import keyword_extraction
+from rag.prompts.generator import keyword_extraction
 from rag.utils.storage_factory import STORAGE_IMPL
 
 from api.db.services.canvas_service import UserCanvasService
@@ -536,7 +536,7 @@ def list_chunks():
     )
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
 
-    res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
+    res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
     res = [
         {
             "content": res_item["content_with_weight"],
@@ -884,7 +884,7 @@ def retrieval():
     if req.get("keyword", False):
         chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
         question += keyword_extraction(chat_mdl, question)
-    ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+    ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                          similarity_threshold, vector_similarity_weight, top,
                                          doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
                                          rank_feature=label_question(question, kbs))
@@ -19,15 +19,19 @@ import re
 import sys
 from functools import partial
 
+import flask
 import trio
 from flask import request, Response
 from flask_login import login_required, current_user
 
 from agent.component import LLM
+from api import settings
 from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
 from api.db.services.file_service import FileService
+from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from api.settings import RetCode
@@ -35,25 +39,19 @@ from api.utils import get_uuid
 from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
 from agent.canvas import Canvas
 from peewee import MySQLDatabase, PostgresqlDatabase
-from api.db.db_models import APIToken
+from api.db.db_models import APIToken, Task
 import time
 
 from api.utils.file_utils import filename_type, read_potential_broken_pdf
+from rag.flow.pipeline import Pipeline
+from rag.nlp import search
 from rag.utils.redis_conn import REDIS_CONN
 
 
 @manager.route('/templates', methods=['GET'])  # noqa: F821
 @login_required
 def templates():
-    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
-
-
-@manager.route('/list', methods=['GET'])  # noqa: F821
-@login_required
-def canvas_list():
-    return get_json_result(data=sorted([c.to_dict() for c in \
-                UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.Agent)], key=lambda x: x["update_time"]*-1)
-    )
+    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
 
 
 @manager.route('/rm', methods=['POST'])  # noqa: F821
@@ -77,9 +75,10 @@ def save():
     if not isinstance(req["dsl"], str):
         req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
     req["dsl"] = json.loads(req["dsl"])
+    cate = req.get("canvas_category", CanvasCategory.Agent)
    if "id" not in req:
         req["user_id"] = current_user.id
-        if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.Agent):
+        if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
             return get_data_error_result(message=f"{req['title'].strip()} already exists.")
         req["id"] = get_uuid()
         if not UserCanvasService.save(**req):
@@ -101,7 +100,7 @@ def save():
 def get(canvas_id):
     if not UserCanvasService.accessible(canvas_id, current_user.id):
         return get_data_error_result(message="canvas not found.")
-    e, c = UserCanvasService.get_by_tenant_id(canvas_id)
+    e, c = UserCanvasService.get_by_canvas_id(canvas_id)
     return get_json_result(data=c)
 
 
@ -148,6 +147,14 @@ def run():
|
|||||||
if not isinstance(cvs.dsl, str):
|
if not isinstance(cvs.dsl, str):
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
|
||||||
|
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||||
|
task_id = get_uuid()
|
||||||
|
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||||
|
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||||
|
if not ok:
|
||||||
|
return get_data_error_result(message=error_message)
|
||||||
|
return get_json_result(data={"message_id": task_id})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
|
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -173,6 +180,44 @@ def run():
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||||
|
@validate_request("id", "dsl", "component_id")
|
||||||
|
@login_required
|
||||||
|
def rerun():
|
||||||
|
req = request.json
|
||||||
|
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||||
|
if not doc:
|
||||||
|
return get_data_error_result(message="Document not found.")
|
||||||
|
doc = doc[0]
|
||||||
|
if 0 < doc["progress"] < 1:
|
||||||
|
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||||
|
|
||||||
|
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
|
||||||
|
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||||
|
doc["progress_msg"] = ""
|
||||||
|
doc["chunk_num"] = 0
|
||||||
|
doc["token_num"] = 0
|
||||||
|
DocumentService.clear_chunk_num_when_rerun(doc["id"])
|
||||||
|
DocumentService.update_by_id(id, doc)
|
||||||
|
TaskService.filter_delete([Task.doc_id == id])
|
||||||
|
|
||||||
|
dsl = req["dsl"]
|
||||||
|
dsl["path"] = [req["component_id"]]
|
||||||
|
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||||
|
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def cancel(task_id):
|
||||||
|
try:
|
||||||
|
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||||
@validate_request("id")
|
@validate_request("id")
|
||||||
@login_required
|
@login_required
|
||||||
@ -198,7 +243,7 @@ def reset():
|
|||||||
|
|
||||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||||
def upload(canvas_id):
|
def upload(canvas_id):
|
||||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="canvas not found.")
|
return get_data_error_result(message="canvas not found.")
|
||||||
|
|
||||||
@ -348,6 +393,65 @@ def test_db_connect():
|
|||||||
cursor = db.cursor()
|
cursor = db.cursor()
|
||||||
cursor.execute("SELECT 1")
|
cursor.execute("SELECT 1")
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
elif req["db_type"] == 'IBM DB2':
|
||||||
|
import ibm_db
|
||||||
|
conn_str = (
|
||||||
|
f"DATABASE={req['database']};"
|
||||||
|
f"HOSTNAME={req['host']};"
|
||||||
|
f"PORT={req['port']};"
|
||||||
|
f"PROTOCOL=TCPIP;"
|
||||||
|
f"UID={req['username']};"
|
||||||
|
f"PWD={req['password']};"
|
||||||
|
)
|
||||||
|
logging.info(conn_str)
|
||||||
|
conn = ibm_db.connect(conn_str, "", "")
|
||||||
|
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||||
|
ibm_db.fetch_assoc(stmt)
|
||||||
|
ibm_db.close(conn)
|
||||||
|
return get_json_result(data="Database Connection Successful!")
|
||||||
|
elif req["db_type"] == 'trino':
|
||||||
|
def _parse_catalog_schema(db: str):
|
||||||
|
if not db:
|
||||||
|
return None, None
|
||||||
|
if "." in db:
|
||||||
|
c, s = db.split(".", 1)
|
||||||
|
elif "/" in db:
|
||||||
|
c, s = db.split("/", 1)
|
||||||
|
else:
|
||||||
|
c, s = db, "default"
|
||||||
|
return c, s
|
||||||
|
try:
|
||||||
|
import trino
|
||||||
|
import os
|
||||||
|
from trino.auth import BasicAuthentication
|
||||||
|
except Exception:
|
||||||
|
return server_error_response("Missing dependency 'trino'. Please install: pip install trino")
|
||||||
|
|
||||||
|
catalog, schema = _parse_catalog_schema(req["database"])
|
||||||
|
if not catalog:
|
||||||
|
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
|
||||||
|
|
||||||
|
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||||
|
|
||||||
|
auth = None
|
||||||
|
if http_scheme == "https" and req.get("password"):
|
||||||
|
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||||
|
|
||||||
|
conn = trino.dbapi.connect(
|
||||||
|
host=req["host"],
|
||||||
|
port=int(req["port"] or 8080),
|
||||||
|
user=req["username"] or "ragflow",
|
||||||
|
catalog=catalog,
|
||||||
|
schema=schema or "default",
|
||||||
|
http_scheme=http_scheme,
|
||||||
|
auth=auth
|
||||||
|
)
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute("SELECT 1")
|
||||||
|
cur.fetchall()
|
||||||
|
cur.close()
|
||||||
|
conn.close()
|
||||||
|
return get_json_result(data="Database Connection Successful!")
|
||||||
else:
|
else:
|
||||||
return server_error_response("Unsupported database type.")
|
return server_error_response("Unsupported database type.")
|
||||||
if req["db_type"] != 'mssql':
|
if req["db_type"] != 'mssql':
|
||||||
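For reference, the Trino branch added above derives catalog and schema from the single `database` field. A minimal standalone sketch of that parsing rule, runnable on its own — the helper name is copied from the diff, while the sample values are invented for illustration:

    # Parsing rule used by the Trino branch: "catalog.schema" and
    # "catalog/schema" both split into (catalog, schema); a bare
    # "catalog" falls back to the "default" schema.
    def _parse_catalog_schema(db: str):
        if not db:
            return None, None
        if "." in db:
            c, s = db.split(".", 1)
        elif "/" in db:
            c, s = db.split("/", 1)
        else:
            c, s = db, "default"
        return c, s

    assert _parse_catalog_schema("hive.sales") == ("hive", "sales")
    assert _parse_catalog_schema("hive/sales") == ("hive", "sales")
    assert _parse_catalog_schema("hive") == ("hive", "default")
    assert _parse_catalog_schema("") == (None, None)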
@@ -383,22 +487,32 @@ def getversion( version_id):
         return get_json_result(data=f"Error getting history file: {e}")
 
 
-@manager.route('/listteam', methods=['GET'])  # noqa: F821
+@manager.route('/list', methods=['GET'])  # noqa: F821
 @login_required
 def list_canvas():
     keywords = request.args.get("keywords", "")
-    page_number = int(request.args.get("page", 1))
-    items_per_page = int(request.args.get("page_size", 150))
+    page_number = int(request.args.get("page", 0))
+    items_per_page = int(request.args.get("page_size", 0))
     orderby = request.args.get("orderby", "create_time")
-    desc = request.args.get("desc", True)
-    try:
+    canvas_category = request.args.get("canvas_category")
+    if request.args.get("desc", "true").lower() == "false":
+        desc = False
+    else:
+        desc = True
+    owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
+    if not owner_ids:
         tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
+        tenants = [m["tenant_id"] for m in tenants]
+        tenants.append(current_user.id)
         canvas, total = UserCanvasService.get_by_tenant_ids(
-            [m["tenant_id"] for m in tenants], current_user.id, page_number,
-            items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.Agent)
-        return get_json_result(data={"canvas": canvas, "total": total})
-    except Exception as e:
-        return server_error_response(e)
+            tenants, current_user.id, page_number,
+            items_per_page, orderby, desc, keywords, canvas_category)
+    else:
+        tenants = owner_ids
+        canvas, total = UserCanvasService.get_by_tenant_ids(
+            tenants, current_user.id, 0,
+            0, orderby, desc, keywords, canvas_category)
+    return get_json_result(data={"canvas": canvas, "total": total})
 
 
 @manager.route('/setting', methods=['POST'])  # noqa: F821
@@ -474,7 +588,7 @@ def sessions(canvas_id):
 @manager.route('/prompts', methods=['GET'])  # noqa: F821
 @login_required
 def prompts():
-    from rag.prompts.prompts import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
+    from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
     return get_json_result(data={
         "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
         "plan_generation": NEXT_STEP,
@@ -483,3 +597,11 @@ def prompts():
         #"context_ranking": RANK_MEMORY,
         "citation_guidelines": CITATION_PROMPT_TEMPLATE
     })
+
+
+@manager.route('/download', methods=['GET'])  # noqa: F821
+def download():
+    id = request.args.get("id")
+    created_by = request.args.get("created_by")
+    blob = FileService.get_blob(created_by, id)
+    return flask.make_response(blob)
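For reference, a client-side sketch of the new `/rerun` and `/cancel` endpoints added above. The base URL prefix and the auth header are assumptions (how the blueprint is mounted is not visible in this compare view); the JSON fields follow the diff's `@validate_request("id", "dsl", "component_id")`:

    import requests

    BASE = "http://localhost:9380/v1/canvas"  # assumed mount point, not shown in this diff
    HEADERS = {"Authorization": "Bearer <token>"}  # hypothetical auth header

    # Rerun an ingestion pipeline from a given component. The handler clears the
    # document's chunks, rewrites dsl["path"] to [component_id], and requeues it.
    resp = requests.post(f"{BASE}/rerun", json={
        "id": "<pipeline_log_id>",
        "dsl": {},                      # full DSL; "path" is overwritten server-side
        "component_id": "<component>",  # component to restart from
    }, headers=HEADERS)
    print(resp.json())

    # Cancel a running task: per the diff, this just sets a "<task_id>-cancel"
    # flag in Redis that the worker is expected to poll.
    requests.put(f"{BASE}/cancel/<task_id>", headers=HEADERS)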
@@ -33,8 +33,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
-from rag.prompts import cross_languages, keyword_extraction
-from rag.prompts.prompts import gen_meta_filter
+from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
 from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace
 
@@ -61,7 +60,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
+        sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -347,7 +346,7 @@ def retrieval_test():
             question += keyword_extraction(chat_mdl, question)
 
         labels = label_question(question, [kb])
-        ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                float(req.get("similarity_threshold", 0.0)),
                                float(req.get("vector_similarity_weight", 0.3)),
                                top,
@@ -355,7 +354,7 @@ def retrieval_test():
                                rank_feature=labels
                                )
         if use_kg:
-            ck = settings.kg_retrievaler.retrieval(question,
+            ck = settings.kg_retriever.retrieval(question,
                                                    tenant_ids,
                                                    kb_ids,
                                                    embd_mdl,
@@ -385,7 +384,7 @@ def knowledge_graph():
         "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
+    sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
@@ -15,7 +15,7 @@
 #
 import json
 import re
-import traceback
+import logging
 from copy import deepcopy
 from flask import Response, request
 from flask_login import current_user, login_required
@@ -29,8 +29,8 @@ from api.db.services.search_service import SearchService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
-from rag.prompts.prompt_template import load_prompt
-from rag.prompts.prompts import chunks_format
+from rag.prompts.template import load_prompt
+from rag.prompts.generator import chunks_format
 
 
 @manager.route("/set", methods=["POST"])  # noqa: F821
@@ -226,7 +226,7 @@ def completion():
                 if not is_embedded:
                     ConversationService.update_by_id(conv.id, conv.to_dict())
             except Exception as e:
-                traceback.print_exc()
+                logging.exception(e)
                 yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
             yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
 
@@ -1,353 +0,0 @@
-#
-#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-import json
-import re
-import sys
-import time
-from functools import partial
-
-import trio
-from flask import request
-from flask_login import current_user, login_required
-
-from agent.canvas import Canvas
-from agent.component import LLM
-from api.db import CanvasCategory, FileType
-from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
-from api.db.services.document_service import DocumentService
-from api.db.services.file_service import FileService
-from api.db.services.task_service import queue_dataflow
-from api.db.services.user_canvas_version import UserCanvasVersionService
-from api.db.services.user_service import TenantService
-from api.settings import RetCode
-from api.utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
-from api.utils.file_utils import filename_type, read_potential_broken_pdf
-from rag.flow.pipeline import Pipeline
-
-
-@manager.route("/templates", methods=["GET"])  # noqa: F821
-@login_required
-def templates():
-    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])
-
-
-@manager.route("/list", methods=["GET"])  # noqa: F821
-@login_required
-def canvas_list():
-    return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))
-
-
-@manager.route("/rm", methods=["POST"])  # noqa: F821
-@validate_request("canvas_ids")
-@login_required
-def rm():
-    for i in request.json["canvas_ids"]:
-        if not UserCanvasService.accessible(i, current_user.id):
-            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
-        UserCanvasService.delete_by_id(i)
-    return get_json_result(data=True)
-
-
-@manager.route("/set", methods=["POST"])  # noqa: F821
-@validate_request("dsl", "title")
-@login_required
-def save():
-    req = request.json
-    if not isinstance(req["dsl"], str):
-        req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
-    req["dsl"] = json.loads(req["dsl"])
-    req["canvas_category"] = CanvasCategory.DataFlow
-    if "id" not in req:
-        req["user_id"] = current_user.id
-        if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
-            return get_data_error_result(message=f"{req['title'].strip()} already exists.")
-        req["id"] = get_uuid()
-
-        if not UserCanvasService.save(**req):
-            return get_data_error_result(message="Fail to save canvas.")
-    else:
-        if not UserCanvasService.accessible(req["id"], current_user.id):
-            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
-        UserCanvasService.update_by_id(req["id"], req)
-    # save version
-    UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
-    UserCanvasVersionService.delete_all_versions(req["id"])
-    return get_json_result(data=req)
-
-
-@manager.route("/get/<canvas_id>", methods=["GET"])  # noqa: F821
-@login_required
-def get(canvas_id):
-    if not UserCanvasService.accessible(canvas_id, current_user.id):
-        return get_data_error_result(message="canvas not found.")
-    e, c = UserCanvasService.get_by_tenant_id(canvas_id)
-    return get_json_result(data=c)
-
-
-@manager.route("/run", methods=["POST"])  # noqa: F821
-@validate_request("id")
-@login_required
-def run():
-    req = request.json
-    flow_id = req.get("id", "")
-    doc_id = req.get("doc_id", "")
-    if not all([flow_id, doc_id]):
-        return get_data_error_result(message="id and doc_id are required.")
-
-    if not DocumentService.get_by_id(doc_id):
-        return get_data_error_result(message=f"Document for {doc_id} not found.")
-
-    user_id = req.get("user_id", current_user.id)
-    if not UserCanvasService.accessible(flow_id, current_user.id):
-        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
-
-    e, cvs = UserCanvasService.get_by_id(flow_id)
-    if not e:
-        return get_data_error_result(message="canvas not found.")
-
-    if not isinstance(cvs.dsl, str):
-        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
-
-    task_id = get_uuid()
-
-    ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
-    if not ok:
-        return server_error_response(error_message)
-
-    return get_json_result(data={"task_id": task_id, "flow_id": flow_id})
-
-
-@manager.route("/reset", methods=["POST"])  # noqa: F821
-@validate_request("id")
-@login_required
-def reset():
-    req = request.json
-    flow_id = req.get("id", "")
-    if not flow_id:
-        return get_data_error_result(message="id is required.")
-
-    if not UserCanvasService.accessible(flow_id, current_user.id):
-        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
-
-    task_id = req.get("task_id", "")
-
-    try:
-        e, user_canvas = UserCanvasService.get_by_id(req["id"])
-        if not e:
-            return get_data_error_result(message="canvas not found.")
-
-        dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
-        dataflow.reset()
-        req["dsl"] = json.loads(str(dataflow))
-        UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
-        return get_json_result(data=req["dsl"])
-    except Exception as e:
-        return server_error_response(e)
-
-
-@manager.route("/upload/<canvas_id>", methods=["POST"])  # noqa: F821
-def upload(canvas_id):
-    e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
-    if not e:
-        return get_data_error_result(message="canvas not found.")
-
-    user_id = cvs["user_id"]
-
-    def structured(filename, filetype, blob, content_type):
-        nonlocal user_id
-        if filetype == FileType.PDF.value:
-            blob = read_potential_broken_pdf(blob)
-
-        location = get_uuid()
-        FileService.put_blob(user_id, location, blob)
-
-        return {
-            "id": location,
-            "name": filename,
-            "size": sys.getsizeof(blob),
-            "extension": filename.split(".")[-1].lower(),
-            "mime_type": content_type,
-            "created_by": user_id,
-            "created_at": time.time(),
-            "preview_url": None,
-        }
-
-    if request.args.get("url"):
-        from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter
-
-        try:
-            url = request.args.get("url")
-            filename = re.sub(r"\?.*", "", url.split("/")[-1])
-
-            async def adownload():
-                browser_config = BrowserConfig(
-                    headless=True,
-                    verbose=False,
-                )
-                async with AsyncWebCrawler(config=browser_config) as crawler:
-                    crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
-                    result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
-                    return result
-
-            page = trio.run(adownload())
-            if page.pdf:
-                if filename.split(".")[-1].lower() != "pdf":
-                    filename += ".pdf"
-                return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
-
-            return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
-
-        except Exception as e:
-            return server_error_response(e)
-
-    file = request.files["file"]
-    try:
-        DocumentService.check_doc_health(user_id, file.filename)
-        return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
-    except Exception as e:
-        return server_error_response(e)
-
-
-@manager.route("/input_form", methods=["GET"])  # noqa: F821
-@login_required
-def input_form():
-    flow_id = request.args.get("id")
-    cpn_id = request.args.get("component_id")
-    try:
-        e, user_canvas = UserCanvasService.get_by_id(flow_id)
-        if not e:
-            return get_data_error_result(message="canvas not found.")
-        if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
-            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
-
-        dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")
-
-        return get_json_result(data=dataflow.get_component_input_form(cpn_id))
-    except Exception as e:
-        return server_error_response(e)
-
-
-@manager.route("/debug", methods=["POST"])  # noqa: F821
-@validate_request("id", "component_id", "params")
-@login_required
-def debug():
-    req = request.json
-    if not UserCanvasService.accessible(req["id"], current_user.id):
-        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
-    try:
-        e, user_canvas = UserCanvasService.get_by_id(req["id"])
-        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
-        canvas.reset()
-        canvas.message_id = get_uuid()
-        component = canvas.get_component(req["component_id"])["obj"]
-        component.reset()
-
-        if isinstance(component, LLM):
-            component.set_debug_inputs(req["params"])
-        component.invoke(**{k: o["value"] for k, o in req["params"].items()})
-        outputs = component.output()
-        for k in outputs.keys():
-            if isinstance(outputs[k], partial):
-                txt = ""
-                for c in outputs[k]():
-                    txt += c
-                outputs[k] = txt
-        return get_json_result(data=outputs)
-    except Exception as e:
-        return server_error_response(e)
-
-
-# api get list version dsl of canvas
-@manager.route("/getlistversion/<canvas_id>", methods=["GET"])  # noqa: F821
-@login_required
-def getlistversion(canvas_id):
-    try:
-        list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
-        return get_json_result(data=list)
-    except Exception as e:
-        return get_data_error_result(message=f"Error getting history files: {e}")
-
-
-# api get version dsl of canvas
-@manager.route("/getversion/<version_id>", methods=["GET"])  # noqa: F821
-@login_required
-def getversion(version_id):
-    try:
-        e, version = UserCanvasVersionService.get_by_id(version_id)
-        if version:
-            return get_json_result(data=version.to_dict())
-    except Exception as e:
-        return get_json_result(data=f"Error getting history file: {e}")
-
-
-@manager.route("/listteam", methods=["GET"])  # noqa: F821
-@login_required
-def list_canvas():
-    keywords = request.args.get("keywords", "")
-    page_number = int(request.args.get("page", 1))
-    items_per_page = int(request.args.get("page_size", 150))
-    orderby = request.args.get("orderby", "create_time")
-    desc = request.args.get("desc", True)
-    try:
-        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
-        canvas, total = UserCanvasService.get_by_tenant_ids(
-            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
-        )
-        return get_json_result(data={"canvas": canvas, "total": total})
-    except Exception as e:
-        return server_error_response(e)
-
-
-@manager.route("/setting", methods=["POST"])  # noqa: F821
-@validate_request("id", "title", "permission")
-@login_required
-def setting():
-    req = request.json
-    req["user_id"] = current_user.id
-
-    if not UserCanvasService.accessible(req["id"], current_user.id):
-        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
-
-    e, flow = UserCanvasService.get_by_id(req["id"])
-    if not e:
-        return get_data_error_result(message="canvas not found.")
-    flow = flow.to_dict()
-    flow["title"] = req["title"]
-    for key in ("description", "permission", "avatar"):
-        if value := req.get(key):
-            flow[key] = value
-
-    num = UserCanvasService.update_by_id(req["id"], flow)
-    return get_json_result(data=num)
-
-
-@manager.route("/trace", methods=["GET"])  # noqa: F821
-def trace():
-    dataflow_id = request.args.get("dataflow_id")
-    task_id = request.args.get("task_id")
-    if not all([dataflow_id, task_id]):
-        return get_data_error_result(message="dataflow_id and task_id are required.")
-
-    e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
-    if not e:
-        return get_data_error_result(message="dataflow not found.")
-
-    dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
-    dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
-    log = dataflow.fetch_logs()
-
-    return get_json_result(data=log)
@@ -24,6 +24,7 @@ from flask import request
 from flask_login import current_user, login_required
 
 from api import settings
+from api.common.check_team_permission import check_kb_team_permission
 from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
 from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus
 from api.db.db_models import File, Task
@@ -32,7 +33,7 @@ from api.db.services.document_service import DocumentService, doc_upload_and_par
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks
+from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
 from api.db.services.user_service import UserTenantService
 from api.utils import get_uuid
 from api.utils.api_utils import (
@@ -68,8 +69,10 @@ def upload():
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")
-    err, files = FileService.upload_document(kb, file_objs, current_user.id)
+    if not check_kb_team_permission(kb, current_user.id):
+        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
+
+    err, files = FileService.upload_document(kb, file_objs, current_user.id)
     if err:
         return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
 
@@ -94,6 +97,8 @@ def web_crawl():
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")
+    if check_kb_team_permission(kb, current_user.id):
+        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
 
     blob = html2pdf(url)
     if not blob:
@@ -182,6 +187,7 @@ def create():
         "id": get_uuid(),
         "kb_id": kb.id,
         "parser_id": kb.parser_id,
+        "pipeline_id": kb.pipeline_id,
         "parser_config": kb.parser_config,
         "created_by": current_user.id,
         "type": FileType.VIRTUAL,
@@ -479,8 +485,11 @@ def run():
                 kb_table_num_map[kb_id] = count
             if kb_table_num_map[kb_id] <= 0:
                 KnowledgebaseService.delete_field_map(kb_id)
-            bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-            queue_tasks(doc, bucket, name, 0)
+            if doc.get("pipeline_id", ""):
+                queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
+            else:
+                bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
+                queue_tasks(doc, bucket, name, 0)
 
         return get_json_result(data=True)
     except Exception as e:
@@ -546,31 +555,22 @@ def get(doc_id):
 
 @manager.route("/change_parser", methods=["POST"])  # noqa: F821
 @login_required
-@validate_request("doc_id", "parser_id")
+@validate_request("doc_id")
 def change_parser():
-    req = request.json
 
+    req = request.json
     if not DocumentService.accessible(req["doc_id"], current_user.id):
         return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
-    try:
-        e, doc = DocumentService.get_by_id(req["doc_id"])
+    e, doc = DocumentService.get_by_id(req["doc_id"])
+    if not e:
+        return get_data_error_result(message="Document not found!")
+
+    def reset_doc():
+        nonlocal doc
+        e = DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"], "parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
         if not e:
             return get_data_error_result(message="Document not found!")
-        if doc.parser_id.lower() == req["parser_id"].lower():
-            if "parser_config" in req:
-                if req["parser_config"] == doc.parser_config:
-                    return get_json_result(data=True)
-            else:
-                return get_json_result(data=True)
-
-        if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
-            return get_data_error_result(message="Not supported yet!")
-
-        e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
-        if not e:
-            return get_data_error_result(message="Document not found!")
-        if "parser_config" in req:
-            DocumentService.update_parser_config(doc.id, req["parser_config"])
         if doc.token_num > 0:
             e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
             if not e:
@@ -581,6 +581,26 @@ def change_parser():
             if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
                 settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
 
+    try:
+        if "pipeline_id" in req and req["pipeline_id"] != "":
+            if doc.pipeline_id == req["pipeline_id"]:
+                return get_json_result(data=True)
+            DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
+            reset_doc()
+            return get_json_result(data=True)
+
+        if doc.parser_id.lower() == req["parser_id"].lower():
+            if "parser_config" in req:
+                if req["parser_config"] == doc.parser_config:
+                    return get_json_result(data=True)
+            else:
+                return get_json_result(data=True)
+
+        if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
+            return get_data_error_result(message="Not supported yet!")
+        if "parser_config" in req:
+            DocumentService.update_parser_config(doc.id, req["parser_config"])
+        reset_doc()
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
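For reference, a sketch of calling the reworked `/change_parser` above: the validator now requires only `doc_id`, and a non-empty `pipeline_id` takes the new branch that switches the document's ingestion pipeline and resets it. The base URL and auth header are assumptions; note that `reset_doc()` in this revision also reads `req["parser_id"]`, so both fields are sent:

    import requests

    BASE = "http://localhost:9380/v1/document"  # assumed mount point
    HEADERS = {"Authorization": "Bearer <token>"}  # hypothetical auth header

    # Switch a document to a dataflow pipeline (new branch in the diff).
    requests.post(f"{BASE}/change_parser", headers=HEADERS, json={
        "doc_id": "<doc_id>",
        "pipeline_id": "<canvas_id>",
        "parser_id": "naive",  # still read inside reset_doc() in this revision
    })

    # Classic parser switch; parser_config stays optional.
    requests.post(f"{BASE}/change_parser", headers=HEADERS, json={
        "doc_id": "<doc_id>",
        "parser_id": "naive",
        "parser_config": {"chunk_token_num": 512},
    })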
@@ -21,6 +21,7 @@ import flask
 from flask import request
 from flask_login import login_required, current_user
 
+from api.common.check_team_permission import check_file_team_permission
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
@@ -246,6 +247,8 @@ def rm():
             return get_data_error_result(message="File or Folder not found!")
         if not file.tenant_id:
             return get_data_error_result(message="Tenant not found!")
+        if not check_file_team_permission(file, current_user.id):
+            return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
         if file.source_type == FileSource.KNOWLEDGEBASE:
             continue
 
@@ -292,6 +295,8 @@ def rename():
         e, file = FileService.get_by_id(req["file_id"])
         if not e:
             return get_data_error_result(message="File not found!")
+        if not check_file_team_permission(file, current_user.id):
+            return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
         if file.type != FileType.FOLDER.value \
                 and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
             file.name.lower()).suffix:
@@ -328,6 +333,8 @@ def get(file_id):
         e, file = FileService.get_by_id(file_id)
         if not e:
             return get_data_error_result(message="Document not found!")
+        if not check_file_team_permission(file, current_user.id):
+            return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
 
         blob = STORAGE_IMPL.get(file.parent_id, file.location)
         if not blob:
@@ -367,6 +374,8 @@ def move():
             return get_data_error_result(message="File or Folder not found!")
         if not file.tenant_id:
             return get_data_error_result(message="Tenant not found!")
+        if not check_file_team_permission(file, current_user.id):
+            return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
         fe, _ = FileService.get_by_id(parent_id)
         if not fe:
             return get_data_error_result(message="Parent Folder not found!")
@ -14,18 +14,21 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||||
|
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import StatusEnum, FileSource
|
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.db_models import File
|
from api.db.db_models import File
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
@ -33,6 +36,7 @@ from api import settings
|
|||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
from api.constants import DATASET_NAME_LIMIT
|
from api.constants import DATASET_NAME_LIMIT
|
||||||
from rag.settings import PAGERANK_FLD
|
from rag.settings import PAGERANK_FLD
|
||||||
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
|
||||||
|
|
||||||
@ -61,10 +65,39 @@ def create():
|
|||||||
req["name"] = dataset_name
|
req["name"] = dataset_name
|
||||||
req["tenant_id"] = current_user.id
|
req["tenant_id"] = current_user.id
|
||||||
req["created_by"] = current_user.id
|
req["created_by"] = current_user.id
|
||||||
|
if not req.get("parser_id"):
|
||||||
|
req["parser_id"] = "naive"
|
||||||
e, t = TenantService.get_by_id(current_user.id)
|
e, t = TenantService.get_by_id(current_user.id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Tenant not found.")
|
return get_data_error_result(message="Tenant not found.")
|
||||||
req["embd_id"] = t.embd_id
|
req["parser_config"] = {
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 512,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": False,
|
||||||
|
"topn_tags": 3,
|
||||||
|
"raptor": {
|
||||||
|
"use_raptor": True,
|
||||||
|
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
|
||||||
|
"max_token": 256,
|
||||||
|
"threshold": 0.1,
|
||||||
|
"max_cluster": 64,
|
||||||
|
"random_seed": 0
|
||||||
|
},
|
||||||
|
"graphrag": {
|
||||||
|
"use_graphrag": True,
|
||||||
|
"entity_types": [
|
||||||
|
"organization",
|
||||||
|
"person",
|
||||||
|
"geo",
|
||||||
|
"event",
|
||||||
|
"category"
|
||||||
|
],
|
||||||
|
"method": "light"
|
||||||
|
}
|
||||||
|
}
|
||||||
if not KnowledgebaseService.save(**req):
|
if not KnowledgebaseService.save(**req):
|
||||||
return get_data_error_result()
|
return get_data_error_result()
|
||||||
return get_json_result(data={"kb_id": req["id"]})
|
return get_json_result(data={"kb_id": req["id"]})
|
||||||
@ -155,6 +188,9 @@ def detail():
|
|||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Can't find this knowledgebase!")
|
message="Can't find this knowledgebase!")
|
||||||
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
|
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
|
||||||
|
for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
|
||||||
|
if finish_at := kb.get(key):
|
||||||
|
kb[key] = finish_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
return get_json_result(data=kb)
|
return get_json_result(data=kb)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
@ -250,7 +286,7 @@ def list_tags(kb_id):
|
|||||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||||
tags = []
|
tags = []
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
tags += settings.retrievaler.all_tags(tenant["tenant_id"], [kb_id])
|
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
|
||||||
return get_json_result(data=tags)
|
return get_json_result(data=tags)
|
||||||
|
|
||||||
|
|
||||||
@ -269,7 +305,7 @@ def list_tags_from_kbs():
|
|||||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||||
tags = []
|
tags = []
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
tags += settings.retrievaler.all_tags(tenant["tenant_id"], kb_ids)
|
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
|
||||||
return get_json_result(data=tags)
|
return get_json_result(data=tags)
|
||||||
|
|
||||||
|
|
||||||
@ -330,7 +366,7 @@ def knowledge_graph(kb_id):
|
|||||||
obj = {"graph": {}, "mind_map": {}}
|
obj = {"graph": {}, "mind_map": {}}
|
||||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
|
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
|
||||||
return get_json_result(data=obj)
|
return get_json_result(data=obj)
|
||||||
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
|
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||||
if not len(sres.ids):
|
if not len(sres.ids):
|
||||||
return get_json_result(data=obj)
|
return get_json_result(data=obj)
|
||||||
|
|
||||||
@ -395,3 +431,359 @@ def get_basic_info():
|
|||||||
basic_info = DocumentService.knowledgebase_basic_info(kb_id)
|
basic_info = DocumentService.knowledgebase_basic_info(kb_id)
|
||||||
|
|
||||||
return get_json_result(data=basic_info)
|
return get_json_result(data=basic_info)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def list_pipeline_logs():
|
||||||
|
kb_id = request.args.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
keywords = request.args.get("keywords", "")
|
||||||
|
|
||||||
|
page_number = int(request.args.get("page", 0))
|
||||||
|
items_per_page = int(request.args.get("page_size", 0))
|
||||||
|
orderby = request.args.get("orderby", "create_time")
|
||||||
|
if request.args.get("desc", "true").lower() == "false":
|
||||||
|
desc = False
|
||||||
|
else:
|
||||||
|
desc = True
|
||||||
|
create_date_from = request.args.get("create_date_from", "")
|
||||||
|
create_date_to = request.args.get("create_date_to", "")
|
||||||
|
if create_date_to > create_date_from:
|
||||||
|
return get_data_error_result(message="Create data filter is abnormal.")
|
||||||
|
|
||||||
|
req = request.get_json()
|
||||||
|
|
||||||
|
operation_status = req.get("operation_status", [])
|
||||||
|
if operation_status:
|
||||||
|
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||||
|
if invalid_status:
|
||||||
|
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||||
|
|
||||||
|
types = req.get("types", [])
|
||||||
|
if types:
|
||||||
|
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||||
|
if invalid_types:
|
||||||
|
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||||
|
|
||||||
|
suffix = req.get("suffix", [])
|
||||||
|
|
||||||
|
try:
|
||||||
|
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
|
||||||
|
return get_json_result(data={"total": tol, "logs": logs})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def list_pipeline_dataset_logs():
|
||||||
|
kb_id = request.args.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
page_number = int(request.args.get("page", 0))
|
||||||
|
items_per_page = int(request.args.get("page_size", 0))
|
||||||
|
orderby = request.args.get("orderby", "create_time")
|
||||||
|
if request.args.get("desc", "true").lower() == "false":
|
||||||
|
desc = False
|
||||||
|
else:
|
||||||
|
desc = True
|
||||||
|
create_date_from = request.args.get("create_date_from", "")
|
||||||
|
create_date_to = request.args.get("create_date_to", "")
|
||||||
|
if create_date_to > create_date_from:
|
||||||
|
return get_data_error_result(message="Create data filter is abnormal.")
|
||||||
|
|
||||||
|
req = request.get_json()
|
||||||
|
|
||||||
|
operation_status = req.get("operation_status", [])
|
||||||
|
if operation_status:
|
||||||
|
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||||
|
if invalid_status:
|
||||||
|
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
|
||||||
|
return get_json_result(data={"total": tol, "logs": logs})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def delete_pipeline_logs():
|
||||||
|
kb_id = request.args.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
req = request.get_json()
|
||||||
|
log_ids = req.get("log_ids", [])
|
||||||
|
|
||||||
|
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||||
|
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/pipeline_log_detail", methods=["GET"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def pipeline_log_detail():
|
||||||
|
log_id = request.args.get("log_id")
|
||||||
|
if not log_id:
|
||||||
|
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
ok, log = PipelineOperationLogService.get_by_id(log_id)
|
||||||
|
if not ok:
|
||||||
|
return get_data_error_result(message="Invalid pipeline log ID")
|
||||||
|
|
||||||
|
return get_json_result(data=log.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def run_graphrag():
|
||||||
|
req = request.json
|
||||||
|
|
||||||
|
kb_id = req.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
task_id = kb.graphrag_task_id
|
||||||
|
if task_id:
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||||
|
|
||||||
|
if task and task.progress not in [-1, 1]:
|
||||||
|
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||||
|
|
||||||
|
documents, _ = DocumentService.get_by_kb_id(
|
||||||
|
kb_id=kb_id,
|
||||||
|
page_number=0,
|
||||||
|
items_per_page=0,
|
||||||
|
orderby="create_time",
|
||||||
|
desc=False,
|
||||||
|
keywords="",
|
||||||
|
run_status=[],
|
||||||
|
types=[],
|
||||||
|
suffix=[],
|
||||||
|
)
|
||||||
|
if not documents:
|
||||||
|
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||||
|
|
||||||
|
sample_document = documents[0]
|
||||||
|
document_ids = [document["id"] for document in documents]
|
||||||
|
|
||||||
|
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||||
|
|
||||||
|
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||||
|
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||||
|
|
||||||
|
return get_json_result(data={"graphrag_task_id": task_id})
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def trace_graphrag():
|
||||||
|
kb_id = request.args.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
task_id = kb.graphrag_task_id
|
||||||
|
if not task_id:
|
||||||
|
return get_json_result(data={})
|
||||||
|
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
|
||||||
|
|
||||||
|
return get_json_result(data=task.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
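run_graphrag and trace_graphrag form a fire-and-poll pair: the POST enqueues a single task covering every document in the knowledge base and records its id on the KB row, and the GET replays that task's state. A minimal polling sketch under the same assumptions as the earlier client example (base URL and authenticated session); the terminal progress values -1 (failed) and 1 (finished) mirror the guard in run_graphrag.

import time

def wait_for_graphrag(session, base, kb_id, interval=5.0):
    """Start GraphRAG for a KB, then poll until the task settles."""
    run = session.post(f"{base}/run_graphrag", json={"kb_id": kb_id}).json()
    task_id = run["data"]["graphrag_task_id"]
    while True:
        task = session.get(f"{base}/trace_graphrag", params={"kb_id": kb_id}).json()["data"]
        progress = task.get("progress", 0)
        if progress in (-1, 1):  # -1 failed, 1 finished, per run_graphrag's check
            return task_id, progress
        time.sleep(interval)
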
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def run_raptor():
|
||||||
|
req = request.json
|
||||||
|
|
||||||
|
kb_id = req.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
task_id = kb.raptor_task_id
|
||||||
|
if task_id:
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
|
||||||
|
|
||||||
|
if task and task.progress not in [-1, 1]:
|
||||||
|
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||||
|
|
||||||
|
documents, _ = DocumentService.get_by_kb_id(
|
||||||
|
kb_id=kb_id,
|
||||||
|
page_number=0,
|
||||||
|
items_per_page=0,
|
||||||
|
orderby="create_time",
|
||||||
|
desc=False,
|
||||||
|
keywords="",
|
||||||
|
run_status=[],
|
||||||
|
types=[],
|
||||||
|
suffix=[],
|
||||||
|
)
|
||||||
|
if not documents:
|
||||||
|
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||||
|
|
||||||
|
sample_document = documents[0]
|
||||||
|
document_ids = [document["id"] for document in documents]
|
||||||
|
|
||||||
|
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||||
|
|
||||||
|
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||||
|
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
|
||||||
|
|
||||||
|
return get_json_result(data={"raptor_task_id": task_id})
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/trace_raptor", methods=["GET"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def trace_raptor():
|
||||||
|
kb_id = request.args.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
task_id = kb.raptor_task_id
|
||||||
|
if not task_id:
|
||||||
|
return get_json_result(data={})
|
||||||
|
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||||
|
|
||||||
|
return get_json_result(data=task.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def run_mindmap():
|
||||||
|
req = request.json
|
||||||
|
|
||||||
|
kb_id = req.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
task_id = kb.mindmap_task_id
|
||||||
|
if task_id:
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}")
|
||||||
|
|
||||||
|
if task and task.progress not in [-1, 1]:
|
||||||
|
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.")
|
||||||
|
|
||||||
|
documents, _ = DocumentService.get_by_kb_id(
|
||||||
|
kb_id=kb_id,
|
||||||
|
page_number=0,
|
||||||
|
items_per_page=0,
|
||||||
|
orderby="create_time",
|
||||||
|
desc=False,
|
||||||
|
keywords="",
|
||||||
|
run_status=[],
|
||||||
|
types=[],
|
||||||
|
suffix=[],
|
||||||
|
)
|
||||||
|
if not documents:
|
||||||
|
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||||
|
|
||||||
|
sample_document = documents[0]
|
||||||
|
document_ids = [document["id"] for document in documents]
|
||||||
|
|
||||||
|
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||||
|
|
||||||
|
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
|
||||||
|
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
|
||||||
|
|
||||||
|
return get_json_result(data={"mindmap_task_id": task_id})
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def trace_mindmap():
|
||||||
|
kb_id = request.args.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||||
|
|
||||||
|
task_id = kb.mindmap_task_id
|
||||||
|
if not task_id:
|
||||||
|
return get_json_result(data={})
|
||||||
|
|
||||||
|
ok, task = TaskService.get_by_id(task_id)
|
||||||
|
if not ok:
|
||||||
|
return get_error_data_result(message="Mindmap Task Not Found or Error Occurred")
|
||||||
|
|
||||||
|
return get_json_result(data=task.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/unbind_task", methods=["DELETE"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def delete_kb_task():
|
||||||
|
kb_id = request.args.get("kb_id", "")
|
||||||
|
if not kb_id:
|
||||||
|
return get_error_data_result(message='Lack of "KB ID"')
|
||||||
|
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not ok:
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
pipeline_task_type = request.args.get("pipeline_task_type", "")
|
||||||
|
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||||
|
return get_error_data_result(message="Invalid task type")
|
||||||
|
|
||||||
|
match pipeline_task_type:
|
||||||
|
case PipelineTaskType.GRAPH_RAG:
|
||||||
|
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||||
|
kb_task_id_field = "graphrag_task_id"
|
||||||
|
task_id = kb.graphrag_task_id
|
||||||
|
kb_task_finish_at = "graphrag_task_finish_at"
|
||||||
|
case PipelineTaskType.RAPTOR:
|
||||||
|
kb_task_id_field = "raptor_task_id"
|
||||||
|
task_id = kb.raptor_task_id
|
||||||
|
kb_task_finish_at = "raptor_task_finish_at"
|
||||||
|
case PipelineTaskType.MINDMAP:
|
||||||
|
kb_task_id_field = "mindmap_task_id"
|
||||||
|
task_id = kb.mindmap_task_id
|
||||||
|
kb_task_finish_at = "mindmap_task_finish_at"
|
||||||
|
case _:
|
||||||
|
return get_error_data_result(message="Internal Error: Invalid task type")
|
||||||
|
|
||||||
|
def cancel_task(task_id):
|
||||||
|
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||||
|
cancel_task(task_id)
|
||||||
|
|
||||||
|
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id_field: "", kb_task_finish_at: None})
|
||||||
|
if not ok:
|
||||||
|
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
|
||||||
|
|
||||||
|
return get_json_result(data=True)
|
||||||
|
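Note that unbind_task never touches the running process directly: cancellation is cooperative, signalled through the "{task_id}-cancel" key written above, and it is up to the task executor to notice the flag between units of work. A sketch of that consumer side, with the key name taken from cancel_task and the loop itself purely illustrative.

def is_cancelled(redis_conn, task_id: str) -> bool:
    """True once /unbind_task has written the cooperative cancel flag."""
    return redis_conn.get(f"{task_id}-cancel") is not None

# Illustrative executor loop, not the actual task-executor code:
# for chunk in pending_chunks:
#     if is_cancelled(REDIS_CONN, task_id):
#         break  # stop between work units; nothing is killed mid-chunk
#     process(chunk)
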
@@ -1,8 +1,26 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
 from flask import Response
 from flask_login import login_required
 from api.utils.api_utils import get_json_result
 from plugin import GlobalPluginManager


 @manager.route('/llm_tools', methods=['GET'])  # noqa: F821
 @login_required
 def llm_tools() -> Response:
@@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_error_data_result, ge
 from api.utils.api_utils import get_result
 from flask import request
+

 @manager.route('/agents', methods=['GET'])  # noqa: F821
 @token_required
 def list_agents(tenant_id):
@@ -41,7 +42,7 @@ def list_agents(tenant_id):
         desc = False
     else:
         desc = True
-    canvas = UserCanvasService.get_list(tenant_id,page_number,items_per_page,orderby,desc,id,title)
+    canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
     return get_result(data=canvas)


@@ -93,7 +94,7 @@ def update_agent(tenant_id: str, agent_id: str):
         req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)

     req["dsl"] = json.loads(req["dsl"])

     if req.get("title") is not None:
         req["title"] = req["title"].strip()

@@ -215,7 +215,8 @@ def delete(tenant_id):
             continue
         kb_id_instance_pairs.append((kb_id, kb))
     if len(error_kb_ids) > 0:
-        return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
+        return get_error_permission_result(
+            message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")

     errors = []
     success_count = 0
@@ -232,7 +233,8 @@ def delete(tenant_id):
                 ]
             )
             File2DocumentService.delete_by_document_id(doc.id)
-        FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
+        FileService.filter_delete(
+            [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
         if not KnowledgebaseService.delete_by_id(kb_id):
             errors.append(f"Delete dataset error for {kb_id}")
             continue
@@ -329,7 +331,8 @@ def update(tenant_id, dataset_id):
     try:
         kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
         if kb is None:
-            return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
+            return get_error_permission_result(
+                message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")

         if req.get("parser_config"):
             req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
@@ -341,7 +344,8 @@ def update(tenant_id, dataset_id):
             del req["parser_config"]

         if "name" in req and req["name"].lower() != kb.name.lower():
-            exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
+            exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id,
+                                                      status=StatusEnum.VALID.value)
             if exists:
                 return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")

@@ -349,7 +353,8 @@ def update(tenant_id, dataset_id):
             if not req["embd_id"]:
                 req["embd_id"] = kb.embd_id
             if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
-                return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
+                return get_error_data_result(
+                    message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
             ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
             if not ok:
                 return err
@@ -359,10 +364,12 @@ def update(tenant_id, dataset_id):
                 return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")

             if req["pagerank"] > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
+                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+                                             search.index_name(kb.tenant_id), kb.id)
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
+                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+                                             search.index_name(kb.tenant_id), kb.id)

         if not KnowledgebaseService.update_by_id(kb.id, req):
             return get_error_data_result(message="Update dataset error.(Database error)")
@@ -454,7 +461,7 @@ def list_datasets(tenant_id):
             return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")

         tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
-        kbs = KnowledgebaseService.get_list(
+        kbs, total = KnowledgebaseService.get_list(
            [m["tenant_id"] for m in tenants],
            tenant_id,
            args["page"],
@@ -468,14 +475,15 @@ def list_datasets(tenant_id):
         response_data_list = []
         for kb in kbs:
             response_data_list.append(remap_dictionary_keys(kb))
-        return get_result(data=response_data_list)
+        return get_result(data=response_data_list, total=total)
     except OperationalError as e:
         logging.exception(e)
         return get_error_data_result(message="Database operation failed")


 @manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET'])  # noqa: F821
 @token_required
-def knowledge_graph(tenant_id,dataset_id):
+def knowledge_graph(tenant_id, dataset_id):
     if not KnowledgebaseService.accessible(dataset_id, tenant_id):
         return get_result(
             data=False,
@@ -491,7 +499,7 @@ def knowledge_graph(tenant_id,dataset_id):
     obj = {"graph": {}, "mind_map": {}}
     if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
         return get_result(data=obj)
-    sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [dataset_id])
+    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
     if not len(sres.ids):
         return get_result(data=obj)

@@ -507,14 +515,16 @@ def knowledge_graph(tenant_id,dataset_id):
     if "nodes" in obj["graph"]:
         obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
     if "edges" in obj["graph"]:
-        node_id_set = { o["id"] for o in obj["graph"]["nodes"] }
-        filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
+        node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
+        filtered_edges = [o for o in obj["graph"]["edges"] if
+                          o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
         obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
     return get_result(data=obj)


 @manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE'])  # noqa: F821
 @token_required
-def delete_knowledge_graph(tenant_id,dataset_id):
+def delete_knowledge_graph(tenant_id, dataset_id):
     if not KnowledgebaseService.accessible(dataset_id, tenant_id):
         return get_result(
             data=False,
@@ -522,6 +532,7 @@ def delete_knowledge_graph(tenant_id,dataset_id):
             code=settings.RetCode.AUTHENTICATION_ERROR
         )
     _, kb = KnowledgebaseService.get_by_id(dataset_id)
-    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id)
+    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
+                                 search.index_name(kb.tenant_id), dataset_id)

     return get_result(data=True)
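The behavioral change buried in the reformatting above is that KnowledgebaseService.get_list now returns a (rows, total) pair and list_datasets forwards that total, so clients can paginate without probing for an empty page. A sketch of the consuming side; the /api/v1 prefix, the port, and a top-level total field in get_result's envelope are assumptions here.

import requests

resp = requests.get(
    "http://localhost:9380/api/v1/datasets",  # assumed mount point
    headers={"Authorization": "Bearer <RAGFLOW_API_KEY>"},
    params={"page": 1, "page_size": 30},
)
body = resp.json()
datasets = body["data"]
# 'total' is the new overall row count; exactly where it sits in the
# envelope depends on get_result, top-level is assumed here.
print(f"page holds {len(datasets)} of {body.get('total', '?')} datasets")
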
@@ -1,4 +1,4 @@
 #
 #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,6 +31,89 @@ from api.db.services.dialog_service import meta_filter, convert_conditions
 @apikey_required
 @validate_request("knowledge_id", "query")
 def retrieval(tenant_id):
+    """
+    Dify-compatible retrieval API
+    ---
+    tags:
+      - SDK
+    security:
+      - ApiKeyAuth: []
+    parameters:
+      - in: body
+        name: body
+        required: true
+        schema:
+          type: object
+          required:
+            - knowledge_id
+            - query
+          properties:
+            knowledge_id:
+              type: string
+              description: Knowledge base ID
+            query:
+              type: string
+              description: Query text
+            use_kg:
+              type: boolean
+              description: Whether to use knowledge graph
+              default: false
+            retrieval_setting:
+              type: object
+              description: Retrieval configuration
+              properties:
+                score_threshold:
+                  type: number
+                  description: Similarity threshold
+                  default: 0.0
+                top_k:
+                  type: integer
+                  description: Number of results to return
+                  default: 1024
+            metadata_condition:
+              type: object
+              description: Metadata filter condition
+              properties:
+                conditions:
+                  type: array
+                  items:
+                    type: object
+                    properties:
+                      name:
+                        type: string
+                        description: Field name
+                      comparison_operator:
+                        type: string
+                        description: Comparison operator
+                      value:
+                        type: string
+                        description: Field value
+    responses:
+      200:
+        description: Retrieval succeeded
+        schema:
+          type: object
+          properties:
+            records:
+              type: array
+              items:
+                type: object
+                properties:
+                  content:
+                    type: string
+                    description: Content text
+                  score:
+                    type: number
+                    description: Similarity score
+                  title:
+                    type: string
+                    description: Document title
+                  metadata:
+                    type: object
+                    description: Metadata info
+      404:
+        description: Knowledge base or document not found
+    """
     req = request.json
     question = req["query"]
     kb_id = req["knowledge_id"]
@@ -38,9 +121,9 @@ def retrieval(tenant_id):
     retrieval_setting = req.get("retrieval_setting", {})
     similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
     top = int(retrieval_setting.get("top_k", 1024))
-    metadata_condition = req.get("metadata_condition",{})
+    metadata_condition = req.get("metadata_condition", {})
     metas = DocumentService.get_meta_by_kbs([kb_id])

     doc_ids = []
     try:

@@ -50,12 +133,12 @@ def retrieval(tenant_id):

         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
         print(metadata_condition)
-        print("after",convert_conditions(metadata_condition))
+        # print("after", convert_conditions(metadata_condition))
         doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
-        print("doc_ids",doc_ids)
+        # print("doc_ids", doc_ids)
         if not doc_ids and metadata_condition is not None:
             doc_ids = ['-999']
-        ranks = settings.retrievaler.retrieval(
+        ranks = settings.retriever.retrieval(
             question,
             embd_mdl,
             kb.tenant_id,
@@ -70,17 +153,17 @@ def retrieval(tenant_id):
         )

         if use_kg:
-            ck = settings.kg_retrievaler.retrieval(question,
+            ck = settings.kg_retriever.retrieval(question,
                                                    [tenant_id],
                                                    [kb_id],
                                                    embd_mdl,
                                                    LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)

         records = []
         for c in ranks["chunks"]:
-            e, doc = DocumentService.get_by_id( c["doc_id"])
+            e, doc = DocumentService.get_by_id(c["doc_id"])
             c.pop("vector", None)
             meta = getattr(doc, 'meta_fields', {})
             meta["doc_id"] = c["doc_id"]
@@ -100,5 +183,3 @@ def retrieval(tenant_id):
             )
         logging.exception(e)
         return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
-
-
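A request that exercises the Dify-compatible schema documented above. The route path and port are assumptions (the mount point is outside this hunk), and the comparison_operator vocabulary is defined by convert_conditions, which this diff does not show, so "contains" is illustrative.

import requests

URL = "http://localhost:9380/api/v1/dify/retrieval"  # assumed path

payload = {
    "knowledge_id": "<kb-id>",
    "query": "How is pagerank stored?",
    "use_kg": False,
    "retrieval_setting": {"score_threshold": 0.2, "top_k": 8},
    "metadata_condition": {
        "conditions": [
            # Operator name is illustrative; see convert_conditions.
            {"name": "author", "comparison_operator": "contains", "value": "alice"}
        ]
    },
}
resp = requests.post(URL, json=payload, headers={"Authorization": "Bearer <API_KEY>"})
for record in resp.json().get("records", []):
    print(round(record["score"], 3), record["title"])
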
@@ -40,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
-from rag.prompts import cross_languages, keyword_extraction
+from rag.prompts.generator import cross_languages, keyword_extraction
 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL

@@ -458,7 +458,7 @@ def list_docs(dataset_id, tenant_id):
         required: false
         default: true
         description: Order in descending.
       - in: query
         name: create_time_from
         type: integer
         required: false
@@ -982,7 +982,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
             _ = Chunk(**final_chunk)

     elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-        sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+        sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
         res["total"] = sres.total
         for id in sres.ids:
             d = {
@@ -1446,7 +1446,7 @@ def retrieval_test(tenant_id):
         chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
         question += keyword_extraction(chat_mdl, question)

-    ranks = settings.retrievaler.retrieval(
+    ranks = settings.retriever.retrieval(
         question,
         embd_mdl,
         tenant_ids,
@@ -1462,7 +1462,7 @@ def retrieval_test(tenant_id):
         rank_feature=label_question(question, kbs),
     )
     if use_kg:
-        ck = settings.kg_retrievaler.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
+        ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
         if ck["content_with_weight"]:
             ranks["chunks"].insert(0, ck)
@@ -1,3 +1,20 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
 import pathlib
 import re

@@ -17,7 +34,8 @@ from api.utils.api_utils import get_json_result
 from api.utils.file_utils import filename_type
 from rag.utils.storage_factory import STORAGE_IMPL

+
 @manager.route('/file/upload', methods=['POST'])  # noqa: F821
 @token_required
 def upload(tenant_id):
     """
@@ -44,22 +62,22 @@ def upload(tenant_id):
           type: object
           properties:
             data:
               type: array
               items:
                 type: object
                 properties:
                   id:
                     type: string
                     description: File ID
                   name:
                     type: string
                     description: File name
                   size:
                     type: integer
                     description: File size in bytes
                   type:
                     type: string
                     description: File type (e.g., document, folder)
     """
     pf_id = request.form.get("parent_id")

@@ -83,26 +101,28 @@ def upload(tenant_id):
         return get_json_result(data=False, message="Can't find this folder!", code=404)

     for file_obj in file_objs:
-        # 文件路径处理
+        # Handle the file path
         full_path = '/' + file_obj.filename
         file_obj_names = full_path.split('/')
         file_len = len(file_obj_names)

-        # 获取文件夹路径ID
+        # Get the folder path IDs
         file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
         len_id_list = len(file_id_list)

-        # 创建文件夹结构
+        # Create the folder structure
         if file_len != len_id_list:
             e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
             if not e:
                 return get_json_result(data=False, message="Folder not found!", code=404)
-            last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, len_id_list)
+            last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
+                                                    len_id_list)
         else:
             e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
             if not e:
                 return get_json_result(data=False, message="Folder not found!", code=404)
-            last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, len_id_list)
+            last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
+                                                    len_id_list)

         filetype = filename_type(file_obj_names[file_len - 1])
         location = file_obj_names[file_len - 1]
@@ -129,7 +149,7 @@ def upload(tenant_id):
         return server_error_response(e)


 @manager.route('/file/create', methods=['POST'])  # noqa: F821
 @token_required
 def create(tenant_id):
     """
@@ -207,7 +227,7 @@ def create(tenant_id):
         return server_error_response(e)


 @manager.route('/file/list', methods=['GET'])  # noqa: F821
 @token_required
 def list_files(tenant_id):
     """
@@ -299,7 +319,7 @@ def list_files(tenant_id):
         return server_error_response(e)


 @manager.route('/file/root_folder', methods=['GET'])  # noqa: F821
 @token_required
 def get_root_folder(tenant_id):
     """
@@ -335,7 +355,7 @@ def get_root_folder(tenant_id):
         return server_error_response(e)


 @manager.route('/file/parent_folder', methods=['GET'])  # noqa: F821
 @token_required
 def get_parent_folder():
     """
@@ -380,7 +400,7 @@ def get_parent_folder():
         return server_error_response(e)


 @manager.route('/file/all_parent_folder', methods=['GET'])  # noqa: F821
 @token_required
 def get_all_parent_folders(tenant_id):
     """
@@ -428,7 +448,7 @@ def get_all_parent_folders(tenant_id):
         return server_error_response(e)


 @manager.route('/file/rm', methods=['POST'])  # noqa: F821
 @token_required
 def rm(tenant_id):
     """
@@ -502,7 +522,7 @@ def rm(tenant_id):
         return server_error_response(e)


 @manager.route('/file/rename', methods=['POST'])  # noqa: F821
 @token_required
 def rename(tenant_id):
     """
@@ -542,7 +562,8 @@ def rename(tenant_id):
         if not e:
             return get_json_result(message="File not found!", code=404)

-        if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(file.name.lower()).suffix:
+        if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
+                file.name.lower()).suffix:
             return get_json_result(data=False, message="The extension of file can't be changed", code=400)

         for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
@@ -562,9 +583,9 @@ def rename(tenant_id):
         return server_error_response(e)


 @manager.route('/file/get/<file_id>', methods=['GET'])  # noqa: F821
 @token_required
-def get(tenant_id,file_id):
+def get(tenant_id, file_id):
     """
     Download a file.
     ---
@@ -610,7 +631,7 @@ def get(tenant_id,file_id):
         return server_error_response(e)


 @manager.route('/file/mv', methods=['POST'])  # noqa: F821
 @token_required
 def move(tenant_id):
     """
@@ -669,6 +690,7 @@ def move(tenant_id):
     except Exception as e:
         return server_error_response(e)

+
 @manager.route('/file/convert', methods=['POST'])  # noqa: F821
 @token_required
 def convert(tenant_id):
@@ -735,4 +757,4 @@ def convert(tenant_id):
             file2documents.append(file2document.to_json())
         return get_json_result(data=file2documents)
     except Exception as e:
         return server_error_response(e)
@ -36,11 +36,11 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
|
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||||
|
get_result, server_error_response, token_required, validate_request
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.prompts import chunks_format
|
from rag.prompts.template import load_prompt
|
||||||
from rag.prompts.prompt_template import load_prompt
|
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||||
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
|
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||||
@ -89,7 +89,8 @@ def create_agent_session(tenant_id, agent_id):
|
|||||||
canvas.reset()
|
canvas.reset()
|
||||||
|
|
||||||
cvs.dsl = json.loads(str(canvas))
|
cvs.dsl = json.loads(str(canvas))
|
||||||
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id,
|
||||||
|
"message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||||
API4ConversationService.save(**conv)
|
API4ConversationService.save(**conv)
|
||||||
conv["agent_id"] = conv.pop("dialog_id")
|
conv["agent_id"] = conv.pop("dialog_id")
|
||||||
return get_result(data=conv)
|
return get_result(data=conv)
|
||||||
@ -280,7 +281,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||||
if reasoning_match:
|
if reasoning_match:
|
||||||
reasoning_part = reasoning_match.group(1)
|
reasoning_part = reasoning_match.group(1)
|
||||||
content_part = answer[reasoning_match.end() :]
|
content_part = answer[reasoning_match.end():]
|
||||||
else:
|
else:
|
||||||
reasoning_part = ""
|
reasoning_part = ""
|
||||||
content_part = answer
|
content_part = answer
|
||||||
@ -325,7 +326,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||||||
response["choices"][0]["delta"]["content"] = None
|
response["choices"][0]["delta"]["content"] = None
|
||||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||||
response["choices"][0]["finish_reason"] = "stop"
|
response["choices"][0]["finish_reason"] = "stop"
|
||||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
|
||||||
|
"total_tokens": len(prompt) + token_used}
|
||||||
if need_reference:
|
if need_reference:
|
||||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||||
@ -560,7 +562,8 @@ def list_agent_session(tenant_id, agent_id):
|
|||||||
desc = True
|
desc = True
|
||||||
# dsl defaults to True in all cases except for False and false
|
# dsl defaults to True in all cases except for False and false
|
||||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||||
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
|
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
|
||||||
|
user_id, include_dsl)
|
||||||
if not convs:
|
if not convs:
|
||||||
return get_result(data=[])
|
return get_result(data=[])
|
||||||
for conv in convs:
|
for conv in convs:
|
||||||
@ -582,7 +585,8 @@ def list_agent_session(tenant_id, agent_id):
|
|||||||
if message_num != 0 and messages[message_num]["role"] != "user":
|
if message_num != 0 and messages[message_num]["role"] != "user":
|
||||||
chunk_list = []
|
chunk_list = []
|
||||||
# Add boundary and type checks to prevent KeyError
|
# Add boundary and type checks to prevent KeyError
|
||||||
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
|
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(
|
||||||
|
conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
|
||||||
chunks = conv["reference"][chunk_num]["chunks"]
|
chunks = conv["reference"][chunk_num]["chunks"]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
# Ensure chunk is a dictionary before calling get method
|
# Ensure chunk is a dictionary before calling get method
|
||||||
@ -640,13 +644,16 @@ def delete(tenant_id, chat_id):
|
|||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
return get_result(data={"success_count": success_count, "errors": errors},
|
||||||
|
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||||
else:
|
else:
|
||||||
return get_error_data_result(message="; ".join(errors))
|
return get_error_data_result(message="; ".join(errors))
|
||||||
|
|
||||||
if duplicate_messages:
|
if duplicate_messages:
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
return get_result(
|
||||||
|
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||||
|
data={"success_count": success_count, "errors": duplicate_messages})
|
||||||
else:
|
else:
|
||||||
return get_error_data_result(message=";".join(duplicate_messages))
|
return get_error_data_result(message=";".join(duplicate_messages))
|
||||||
|
|
||||||
@ -692,13 +699,16 @@ def delete_agent_session(tenant_id, agent_id):
|
|||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
return get_result(data={"success_count": success_count, "errors": errors},
|
||||||
|
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||||
else:
|
else:
|
||||||
return get_error_data_result(message="; ".join(errors))
|
return get_error_data_result(message="; ".join(errors))
|
||||||
|
|
||||||
if duplicate_messages:
|
if duplicate_messages:
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
return get_result(
|
||||||
|
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||||
|
data={"success_count": success_count, "errors": duplicate_messages})
|
||||||
else:
|
else:
|
||||||
return get_error_data_result(message=";".join(duplicate_messages))
|
return get_error_data_result(message=";".join(duplicate_messages))
|
||||||
|
|
||||||
@ -731,7 +741,9 @@ def ask_about(tenant_id):
|
|||||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps(
|
||||||
|
{"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||||
|
ensure_ascii=False) + "\n\n"
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
resp = Response(stream(), mimetype="text/event-stream")
|
resp = Response(stream(), mimetype="text/event-stream")
|
||||||
@ -883,7 +895,9 @@ def begin_inputs(agent_id):
|
|||||||
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
||||||
|
|
||||||
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
|
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
|
||||||
return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
|
return get_result(
|
||||||
|
data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
|
||||||
|
"prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
||||||
@ -912,7 +926,9 @@ def ask_about_embedded():
|
|||||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps(
|
||||||
|
{"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||||
|
ensure_ascii=False) + "\n\n"
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
resp = Response(stream(), mimetype="text/event-stream")
|
resp = Response(stream(), mimetype="text/event-stream")
|
||||||
@@ -979,7 +995,8 @@ def retrieval_test_embedded():
                tenant_ids.append(tenant.tenant_id)
                break
        else:
            return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
            return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.",
                                   code=settings.RetCode.OPERATING_ERROR)

        e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
        if not e:
@@ -999,11 +1016,13 @@ def retrieval_test_embedded():
            question += keyword_extraction(chat_mdl, question)

        labels = label_question(question, [kb])
        ranks = settings.retrievaler.retrieval(
        ranks = settings.retriever.retrieval(
            question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
            question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
            doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
        )
        if use_kg:
            ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
            ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
                                                 LLMBundle(kb.tenant_id, LLMType.CHAT))
            if ck["content_with_weight"]:
                ranks["chunks"].insert(0, ck)

@@ -1014,7 +1033,8 @@ def retrieval_test_embedded():
        return get_json_result(data=ranks)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
            return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
                                   code=settings.RetCode.DATA_ERROR)
        return server_error_response(e)

@@ -1083,7 +1103,8 @@ def detail_share_embedded():
            if SearchService.query(tenant_id=tenant.tenant_id, id=search_id):
                break
        else:
            return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)
            return get_json_result(data=False, message="Has no permission for this operation.",
                                   code=settings.RetCode.OPERATING_ERROR)

        search = SearchService.get_detail(search_id)
        if not search:
@@ -39,6 +39,7 @@ from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from api.utils.health_utils import run_health_checks


@manager.route("/version", methods=["GET"])  # noqa: F821
@login_required
def version():
@@ -161,7 +162,7 @@ def status():
        task_executors = REDIS_CONN.smembers("TASKEXE")
        now = datetime.now().timestamp()
        for task_executor_id in task_executors:
            heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60*30, now)
            heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now)
            heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
            task_executor_heartbeats[task_executor_id] = heartbeats
    except Exception:
@@ -177,6 +178,11 @@ def healthz():
    return jsonify(result), (200 if all_ok else 500)


@manager.route("/ping", methods=["GET"])  # noqa: F821
def ping():
    return "pong", 200


@manager.route("/new_token", methods=["POST"])  # noqa: F821
@login_required
def new_token():
@@ -268,7 +274,8 @@ def token_list():
        objs = [o.to_dict() for o in objs]
        for o in objs:
            if not o["beta"]:
                o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace("ragflow-", "")[:32]
                o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace(
                    "ragflow-", "")[:32]
                APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
        return get_json_result(data=objs)
    except Exception as e:
@@ -70,7 +70,8 @@ def create(tenant_id):
            return get_data_error_result(message=f"{invite_user_email} is already in the team.")
        if user_tenant_role == UserTenantRole.OWNER:
            return get_data_error_result(message=f"{invite_user_email} is the owner of the team.")
        return get_data_error_result(message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")
        return get_data_error_result(
            message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")

    UserTenantService.save(
        id=get_uuid(),
@@ -132,7 +133,8 @@ def tenant_list():
@login_required
def agree(tenant_id):
    try:
        UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
        UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id],
                                        {"role": UserTenantRole.NORMAL})
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@@ -34,7 +34,6 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
from api.utils import (
    current_timestamp,
    datetime_format,
    decrypt,
    download_img,
    get_format_time,
    get_uuid,
@@ -46,6 +45,7 @@ from api.utils.api_utils import (
    server_error_response,
    validate_request,
)
from api.utils.crypt import decrypt


@manager.route("/login", methods=["POST", "GET"])  # noqa: F821
@@ -98,7 +98,14 @@ def login():
        return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")

    user = UserService.query_user(email, password)
    if user:
    if user and hasattr(user, 'is_active') and user.is_active == "0":
        return get_json_result(
            data=False,
            code=settings.RetCode.FORBIDDEN,
            message="This account has been disabled, please contact the administrator!",
        )
    elif user:
        response_data = user.to_json()
        user.access_token = get_uuid()
        login_user(user)
@@ -227,6 +234,9 @@ def oauth_callback(channel):
        # User exists, try to log in
        user = users[0]
        user.access_token = get_uuid()
        if user and hasattr(user, 'is_active') and user.is_active == "0":
            return redirect("/?error=user_inactive")

        login_user(user)
        user.save()
        return redirect(f"/?auth={user.get_id()}")
@@ -317,6 +327,8 @@ def github_callback():
        # User has already registered, try to log in
        user = users[0]
        user.access_token = get_uuid()
        if user and hasattr(user, 'is_active') and user.is_active == "0":
            return redirect("/?error=user_inactive")
        login_user(user)
        user.save()
        return redirect("/?auth=%s" % user.get_id())
@@ -418,6 +430,8 @@ def feishu_callback():

        # User has already registered, try to log in
        user = users[0]
        if user and hasattr(user, 'is_active') and user.is_active == "0":
            return redirect("/?error=user_inactive")
        user.access_token = get_uuid()
        login_user(user)
        user.save()
2
api/common/README.md
Normal file
@@ -0,0 +1,2 @@
The python files in this directory are shared between service. They contain common utilities, models, and functions that can be used across various
services to ensure consistency and reduce code duplication.
21
api/common/base64.py
Normal file
@@ -0,0 +1,21 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import base64


def encode_to_base64(input_string):
    base64_encoded = base64.b64encode(input_string.encode('utf-8'))
    return base64_encoded.decode('utf-8')
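
For reference, the new helper is a thin wrapper over the standard library and round-trips cleanly; a minimal usage sketch (the sample string is illustrative, not from the repo):

    import base64

    from api.common.base64 import encode_to_base64

    encoded = encode_to_base64("ragflow")
    assert encoded == "cmFnZmxvdw=="
    assert base64.b64decode(encoded).decode('utf-8') == "ragflow"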
59
api/common/check_team_permission.py
Normal file
@@ -0,0 +1,59 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from api.db import TenantPermission
from api.db.db_models import File, Knowledgebase
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService


def check_kb_team_permission(kb: dict | Knowledgebase, other: str) -> bool:
    kb = kb.to_dict() if isinstance(kb, Knowledgebase) else kb

    kb_tenant_id = kb["tenant_id"]

    if kb_tenant_id == other:
        return True

    if kb["permission"] != TenantPermission.TEAM:
        return False

    joined_tenants = TenantService.get_joined_tenants_by_user_id(other)
    return any(tenant["tenant_id"] == kb_tenant_id for tenant in joined_tenants)


def check_file_team_permission(file: dict | File, other: str) -> bool:
    file = file.to_dict() if isinstance(file, File) else file

    file_tenant_id = file["tenant_id"]
    if file_tenant_id == other:
        return True

    file_id = file["id"]

    kb_ids = [kb_info["kb_id"] for kb_info in FileService.get_kb_id_by_file_id(file_id)]

    for kb_id in kb_ids:
        ok, kb = KnowledgebaseService.get_by_id(kb_id)
        if not ok:
            continue

        if check_kb_team_permission(kb, other):
            return True

    return False
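
As far as the diff shows, both helpers are meant as boolean gates in request handlers: ownership short-circuits to True, otherwise TEAM visibility plus shared tenancy is required. A hypothetical call site (the guard function and its wiring are assumptions, not repo code):

    from api.common.check_team_permission import check_kb_team_permission
    from api.db.services.knowledgebase_service import KnowledgebaseService


    def can_read_kb(kb_id: str, user_id: str) -> bool:
        # hypothetical guard; the owner or a team member of the kb's tenant may read
        ok, kb = KnowledgebaseService.get_by_id(kb_id)
        return bool(ok) and check_kb_team_permission(kb, user_id)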
38
api/common/exceptions.py
Normal file
@@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


class AdminException(Exception):
    def __init__(self, message, code=400):
        super().__init__(message)
        self.type = "admin"
        self.code = code
        self.message = message


class UserNotFoundError(AdminException):
    def __init__(self, username):
        super().__init__(f"User '{username}' not found", 404)


class UserAlreadyExistsError(AdminException):
    def __init__(self, username):
        super().__init__(f"User '{username}' already exists", 409)


class CannotDeleteAdminError(AdminException):
    def __init__(self):
        super().__init__("Cannot delete admin account", 403)
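
The code attribute doubles as an HTTP status, which suggests a single handler can translate the whole hierarchy into JSON responses; a hypothetical sketch (the Flask app and route are ours, only the exception classes come from the diff):

    from flask import Flask, jsonify

    from api.common.exceptions import AdminException, UserNotFoundError

    app = Flask(__name__)


    @app.errorhandler(AdminException)
    def handle_admin_exception(e):
        # UserNotFoundError -> 404, UserAlreadyExistsError -> 409, CannotDeleteAdminError -> 403
        return jsonify({"code": e.code, "message": e.message}), e.code


    @app.route("/admin/users/<username>")
    def get_user(username):
        raise UserNotFoundError(username)  # rendered as a 404 JSON payload by the handler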
@@ -23,6 +23,11 @@ class StatusEnum(Enum):
    INVALID = "0"


class ActiveEnum(Enum):
    ACTIVE = "1"
    INACTIVE = "0"


class UserTenantRole(StrEnum):
    OWNER = 'owner'
    ADMIN = 'admin'
@@ -111,7 +116,7 @@ class CanvasCategory(StrEnum):
    Agent = "agent_canvas"
    DataFlow = "dataflow_canvas"

VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
VALID_CANVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}


class MCPServerType(StrEnum):
@@ -122,4 +127,15 @@ class MCPServerType(StrEnum):
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}


class PipelineTaskType(StrEnum):
    PARSE = "Parse"
    DOWNLOAD = "Download"
    RAPTOR = "RAPTOR"
    GRAPH_RAG = "GraphRAG"
    MINDMAP = "Mindmap"


VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}


KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
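
Since PipelineTaskType subclasses StrEnum, members compare (and hash) as their string values, so raw strings stored on task records can presumably be validated directly against the set; a small illustration (the task_type value is ours):

    from api.db import PipelineTaskType, VALID_PIPELINE_TASK_TYPES

    task_type = "GraphRAG"  # e.g. read back from a pipeline_operation_log row (assumed)
    assert task_type == PipelineTaskType.GRAPH_RAG
    assert task_type in VALID_PIPELINE_TASK_TYPES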
@@ -26,12 +26,14 @@ from functools import wraps

from flask_login import UserMixin
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase

from api import settings, utils
from api.db import ParserType, SerializedType
from api.utils.json import json_dumps, json_loads
from api.utils.configs import deserialize_b64, serialize_b64


def singleton(cls, *args, **kw):
@@ -70,12 +72,12 @@ class JSONField(LongTextField):
    def db_value(self, value):
        if value is None:
            value = self.default_value
        return utils.json_dumps(value)
        return json_dumps(value)

    def python_value(self, value):
        if not value:
            return self.default_value
        return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
        return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)


class ListField(JSONField):
@@ -91,21 +93,21 @@ class SerializedField(LongTextField):

    def db_value(self, value):
        if self._serialized_type == SerializedType.PICKLE:
            return utils.serialize_b64(value, to_str=True)
            return serialize_b64(value, to_str=True)
        elif self._serialized_type == SerializedType.JSON:
            if value is None:
                return None
            return utils.json_dumps(value, with_type=True)
            return json_dumps(value, with_type=True)
        else:
            raise ValueError(f"the serialized type {self._serialized_type} is not supported")

    def python_value(self, value):
        if self._serialized_type == SerializedType.PICKLE:
            return utils.deserialize_b64(value)
            return deserialize_b64(value)
        elif self._serialized_type == SerializedType.JSON:
            if value is None:
                return {}
            return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
            return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
        else:
            raise ValueError(f"the serialized type {self._serialized_type} is not supported")

@@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
        super().__init__(*args, **kwargs)

    def execute_sql(self, sql, params=None, commit=True):
        from peewee import OperationalError

        for attempt in range(self.max_retries + 1):
            try:
                return super().execute_sql(sql, params, commit)
            except OperationalError as e:
            except (OperationalError, InterfaceError) as e:
                if e.args[0] in (2013, 2006) and attempt < self.max_retries:
                    logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
                error_codes = [2013, 2006]
                error_messages = ['', 'Lost connection']
                should_retry = (
                    (hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
                    (str(e) in error_messages) or
                    (hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
                )

                if should_retry and attempt < self.max_retries:
                    logging.warning(
                        f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
                    )
                    self._handle_connection_loss()
                    time.sleep(self.retry_delay * (2**attempt))
                    time.sleep(self.retry_delay * (2 ** attempt))
                else:
                    logging.error(f"DB execution failure: {e}")
                    raise
        return None
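
The loop above is a bounded retry with exponential backoff: attempt n sleeps retry_delay * 2**n before reconnecting, and anything that does not look like a dropped connection is re-raised immediately. The same pattern, reduced to a standalone sketch (names here are illustrative, not repo code):

    import logging
    import time


    def run_with_retry(op, max_retries=5, retry_delay=1.0):
        # call op() up to max_retries + 1 times, backing off 1s, 2s, 4s, ... between failures
        for attempt in range(max_retries + 1):
            try:
                return op()
            except ConnectionError as e:
                if attempt < max_retries:
                    logging.warning(f"retrying (attempt {attempt + 1}/{max_retries}): {e}")
                    time.sleep(retry_delay * (2 ** attempt))
                else:
                    raise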

    def _handle_connection_loss(self):
        self.close_all()
        self.connect()
        # self.close_all()
        # self.connect()
        try:
            self.close()
        except Exception:
            pass
        try:
            self.connect()
        except Exception as e:
            logging.error(f"Failed to reconnect: {e}")
            time.sleep(0.1)
            self.connect()

    def begin(self):
        from peewee import OperationalError

        for attempt in range(self.max_retries + 1):
            try:
                return super().begin()
            except OperationalError as e:
            except (OperationalError, InterfaceError) as e:
                if e.args[0] in (2013, 2006) and attempt < self.max_retries:
                    logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
                error_codes = [2013, 2006]
                error_messages = ['', 'Lost connection']

                should_retry = (
                    (hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
                    (str(e) in error_messages) or
                    (hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
                )

                if should_retry and attempt < self.max_retries:
                    logging.warning(
                        f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
                    )
                    self._handle_connection_loss()
                    time.sleep(self.retry_delay * (2**attempt))
                    time.sleep(self.retry_delay * (2 ** attempt))
                else:
                    raise

@@ -299,7 +328,16 @@ class BaseDataBase:
    def __init__(self):
        database_config = settings.DATABASE.copy()
        db_name = database_config.pop("name")
        self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)

        pool_config = {
            'max_retries': 5,
            'retry_delay': 1,
        }
        database_config.update(pool_config)
        self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
            db_name, **database_config
        )
        # self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
        logging.info("init database on cluster mode successfully")


@@ -603,7 +641,7 @@ class TenantLLM(DataBaseModel):
    llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name", index=True)
    model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
    llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="", index=True)
    api_key = CharField(max_length=2048, null=True, help_text="API KEY", index=True)
    api_key = TextField(null=True, help_text="API KEY")
    api_base = CharField(max_length=255, null=True, help_text="API Base")
    max_tokens = IntegerField(default=8192, index=True)
    used_tokens = IntegerField(default=0, index=True)
@@ -646,8 +684,17 @@ class Knowledgebase(DataBaseModel):
    vector_similarity_weight = FloatField(default=0.3, index=True)

    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
    pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    pagerank = IntegerField(default=0, index=False)

    graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
    graphrag_task_finish_at = DateTimeField(null=True)
    raptor_task_id = CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)
    raptor_task_finish_at = DateTimeField(null=True)
    mindmap_task_id = CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)
    mindmap_task_finish_at = DateTimeField(null=True)

    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
@@ -662,6 +709,7 @@ class Document(DataBaseModel):
    thumbnail = TextField(null=True, help_text="thumbnail base64 string")
    kb_id = CharField(max_length=256, null=False, index=True)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
    pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
    type = CharField(max_length=32, null=False, help_text="file extension", index=True)
@@ -904,6 +952,32 @@ class Search(DataBaseModel):
        db_table = "search"


class PipelineOperationLog(DataBaseModel):
    id = CharField(max_length=32, primary_key=True)
    document_id = CharField(max_length=32, index=True)
    tenant_id = CharField(max_length=32, null=False, index=True)
    kb_id = CharField(max_length=32, null=False, index=True)
    pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
    pipeline_title = CharField(max_length=32, null=True, help_text="Pipeline title", index=True)
    parser_id = CharField(max_length=32, null=False, help_text="Parser ID", index=True)
    document_name = CharField(max_length=255, null=False, help_text="File name")
    document_suffix = CharField(max_length=255, null=False, help_text="File suffix")
    document_type = CharField(max_length=255, null=False, help_text="Document type")
    source_from = CharField(max_length=255, null=False, help_text="Source")
    progress = FloatField(default=0, index=True)
    progress_msg = TextField(null=True, help_text="process message", default="")
    process_begin_at = DateTimeField(null=True, index=True)
    process_duration = FloatField(default=0)
    dsl = JSONField(null=True, default=dict)
    task_type = CharField(max_length=32, null=False, default="")
    operation_status = CharField(max_length=32, null=False, help_text="Operation status")
    avatar = TextField(null=True, help_text="avatar base64 string")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "pipeline_operation_log"


def migrate_db():
    logging.disable(logging.ERROR)
    migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@@ -1020,7 +1094,6 @@ def migrate_db():
        migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
    except Exception:
        pass

    try:
        migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
    except Exception:
@@ -1037,4 +1110,40 @@ def migrate_db():
        migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
    except Exception:
        pass
    try:
        migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
    except Exception:
        pass
    logging.disable(logging.NOTSET)
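
Every schema change in migrate_db() is wrapped in its own try/except Exception: pass, so re-running it against a database that already has the column is harmless and the remaining migrations still execute. A minimal standalone sketch of the same idempotent pattern (SQLite and the table/column names are illustrative):

    from peewee import CharField, SqliteDatabase
    from playhouse.migrate import SqliteMigrator, migrate

    db = SqliteDatabase("example.db")
    migrator = SqliteMigrator(db)


    def add_pipeline_id_column():
        try:
            migrate(migrator.add_column("document", "pipeline_id", CharField(null=True)))
        except Exception:
            pass  # column already exists (or table is missing); keep going


    add_pipeline_id_column()
    add_pipeline_id_column()  # a second run is a no-op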
@@ -14,7 +14,6 @@
# limitations under the License.
#
import logging
import base64
import json
import os
import time
@@ -32,11 +31,7 @@ from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_l
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from api.utils.file_utils import get_project_base_directory
from api.common.base64 import encode_to_base64


def encode_to_base64(input_string):
    base64_encoded = base64.b64encode(input_string.encode('utf-8'))
    return base64_encoded.decode('utf-8')


def init_superuser():
327
api/db/joint_services/user_account_service.py
Normal file
@@ -0,0 +1,327 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import uuid

from api import settings
from api.utils.api_utils import group_by
from api.db import FileType, UserTenantRole, ActiveEnum
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.file_service import FileService
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.search_service import SearchService
from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search


def create_new_user(user_info: dict) -> dict:
    """
    Add a new user, and create tenant, tenant llm, file folder for new user.
    :param user_info: {
        "email": <example@example.com>,
        "nickname": <str, "name">,
        "password": <decrypted password>,
        "login_channel": <enum, "password">,
        "is_superuser": <bool, role == "admin">,
    }
    :return: {
        "success": <bool>,
        "user_info": <dict>,  # if true, return user_info
    }
    """
    # generate user_id and access_token for user
    user_id = uuid.uuid1().hex
    user_info['id'] = user_id
    user_info['access_token'] = uuid.uuid1().hex
    # construct tenant info
    tenant = {
        "id": user_id,
        "name": user_info["nickname"] + "‘s Kingdom",
        "llm_id": settings.CHAT_MDL,
        "embd_id": settings.EMBEDDING_MDL,
        "asr_id": settings.ASR_MDL,
        "parser_ids": settings.PARSERS,
        "img2txt_id": settings.IMAGE2TEXT_MDL,
        "rerank_id": settings.RERANK_MDL,
    }
    usr_tenant = {
        "tenant_id": user_id,
        "user_id": user_id,
        "invited_by": user_id,
        "role": UserTenantRole.OWNER,
    }
    # construct file folder info
    file_id = uuid.uuid1().hex
    file = {
        "id": file_id,
        "parent_id": file_id,
        "tenant_id": user_id,
        "created_by": user_id,
        "name": "/",
        "type": FileType.FOLDER.value,
        "size": 0,
        "location": "",
    }
    try:
        tenant_llm = get_init_tenant_llm(user_id)

        if not UserService.save(**user_info):
            return {"success": False}

        TenantService.insert(**tenant)
        UserTenantService.insert(**usr_tenant)
        TenantLLMService.insert_many(tenant_llm)
        FileService.insert(file)

        return {
            "success": True,
            "user_info": user_info,
        }

    except Exception as create_error:
        logging.exception(create_error)
        # rollback
        try:
            TenantService.delete_by_id(user_id)
        except Exception as e:
            logging.exception(e)
        try:
            u = UserTenantService.query(tenant_id=user_id)
            if u:
                UserTenantService.delete_by_id(u[0].id)
        except Exception as e:
            logging.exception(e)
        try:
            TenantLLMService.delete_by_tenant_id(user_id)
        except Exception as e:
            logging.exception(e)
        try:
            FileService.delete_by_id(file["id"])
        except Exception as e:
            logging.exception(e)
        # delete user row finally
        try:
            UserService.delete_by_id(user_id)
        except Exception as e:
            logging.exception(e)
        # reraise
        raise create_error


def delete_user_data(user_id: str) -> dict:
    # use user_id to delete
    usr = UserService.filter_by_id(user_id)
    if not usr:
        return {"success": False, "message": f"{user_id} can't be found."}
    # check is inactive and not admin
    if usr.is_active == ActiveEnum.ACTIVE.value:
        return {"success": False, "message": f"{user_id} is active and can't be deleted."}
    if usr.is_superuser:
        return {"success": False, "message": "Can't delete the super user."}
    # tenant info
    tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
    owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]

    done_msg = ''
    try:
        # step1. delete owned tenant info
        if owned_tenant:
            done_msg += "Start to delete owned tenant.\n"
            tenant_id = owned_tenant[0]["tenant_id"]
            kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
            # step1.1 delete knowledgebase related file and info
            if kb_ids:
                # step1.1.1 delete files in storage, remove bucket
                for kb_id in kb_ids:
                    if STORAGE_IMPL.bucket_exists(kb_id):
                        STORAGE_IMPL.remove_bucket(kb_id)
                done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
                # step1.1.2 delete file and document info in db
                doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
                if doc_ids:
                    doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
                    done_msg += f"- Deleted {doc_delete_res} document records.\n"
                    task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
                    done_msg += f"- Deleted {task_delete_res} task records.\n"
                file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
                if file_ids:
                    file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
                    done_msg += f"- Deleted {file_delete_res} file records.\n"
                if doc_ids or file_ids:
                    file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
                        [i["id"] for i in doc_ids],
                        [f["id"] for f in file_ids]
                    )
                    done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
                # step1.1.3 delete chunk in es
                r = settings.docStoreConn.delete({"kb_id": kb_ids},
                                                 search.index_name(tenant_id), kb_ids)
                done_msg += f"- Deleted {r} chunk records.\n"
                kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
                done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
            # step1.1.4 delete agents
            agent_delete_res = delete_user_agents(usr.id)
            done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
            # step1.1.5 delete dialogs
            dialog_delete_res = delete_user_dialogs(usr.id)
            done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
            # step1.1.6 delete mcp server
            mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
            done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
            # step1.1.7 delete search
            search_delete_res = SearchService.delete_by_tenant_id(usr.id)
            done_msg += f"- Deleted {search_delete_res} search records.\n"
            # step1.2 delete tenant_llm and tenant_langfuse
            llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
            done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
            langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
            done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
            # step1.3 delete own tenant
            tenant_delete_res = TenantService.delete_by_id(tenant_id)
            done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
        # step2 delete user-tenant relation
        if tenants:
            # step2.1 delete docs and files in joined team
            joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
            if joined_tenants:
                done_msg += "Start to delete data in joined tenants.\n"
                created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
                if created_documents:
                    # step2.1.1 delete files
                    doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
                    created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
                    if created_files:
                        # step2.1.1.1 delete file in storage
                        for f in created_files:
                            STORAGE_IMPL.rm(f.parent_id, f.location)
                        done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
                        # step2.1.1.2 delete file record
                        file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
                        done_msg += f"- Deleted {file_delete_res} file records.\n"
                    # step2.1.2 delete document-file relation record
                    file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
                        [d['id'] for d in created_documents],
                        [f.id for f in created_files]
                    )
                    done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
                    # step2.1.3 delete chunks
                    doc_groups = group_by(created_documents, "tenant_id")
                    kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
                    # chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
                    chunk_delete_res = 0
                    kb_doc_info = {}
                    for _tenant_id, kb_doc in kb_grouped_doc.items():
                        for _kb_id, docs in kb_doc.items():
                            chunk_delete_res += settings.docStoreConn.delete(
                                {"doc_id": [d["id"] for d in docs]},
                                search.index_name(_tenant_id), _kb_id
                            )
                            # record doc info
                            if _kb_id in kb_doc_info.keys():
                                kb_doc_info[_kb_id]['doc_num'] += 1
                                kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
                                kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
                            else:
                                kb_doc_info[_kb_id] = {
                                    'doc_num': 1,
                                    'token_num': sum([d["token_num"] for d in docs]),
                                    'chunk_num': sum([d["chunk_num"] for d in docs])
                                }
                    done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
                    # step2.1.4 delete tasks
                    task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
                    done_msg += f"- Deleted {task_delete_res} tasks.\n"
                    # step2.1.5 delete document record
                    doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
                    done_msg += f"- Deleted {doc_delete_res} documents.\n"
                    # step2.1.6 update knowledge base doc&chunk&token cnt
                    for kb_id, doc_num in kb_doc_info.items():
                        KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)

            # step2.2 delete relation
            user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
            done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
        # step3 finally delete user
        user_delete_res = UserService.delete_by_id(usr.id)
        done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"

        return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}

    except Exception as e:
        logging.exception(e)
        return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}


def delete_user_agents(user_id: str) -> dict:
    """
    use user_id to delete
    :return: {
        "agents_deleted_count": 1,
        "version_deleted_count": 2
    }
    """
    agents_deleted_count, agents_version_deleted_count = 0, 0
    user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
    if user_agents:
        agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
        agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
        agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
    return {
        "agents_deleted_count": agents_deleted_count,
        "version_deleted_count": agents_version_deleted_count
    }


def delete_user_dialogs(user_id: str) -> dict:
    """
    use user_id to delete
    :return: {
        "dialogs_deleted_count": 1,
        "conversations_deleted_count": 1,
        "api_token_deleted_count": 2,
        "api4conversation_deleted_count": 2
    }
    """
    dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
    user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
    if user_dialogs:
        # delete conversation
        conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
        conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
        # delete api token
        api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
        # delete api for conversation
        api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
        # delete dialog at last
        dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
    return {
        "dialogs_deleted_count": dialog_deleted_count,
        "conversations_deleted_count": conversations_deleted_count,
        "api_token_deleted_count": api_token_deleted_count,
        "api4conversation_deleted_count": api4conversation_deleted_count
    }
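
Read together, the two entry points pair up as account setup and teardown, with the guard rails documented above: delete_user_data refuses active accounts and superusers. A hypothetical driver (field values are illustrative):

    from api.db.joint_services.user_account_service import create_new_user, delete_user_data

    result = create_new_user({
        "email": "someone@example.com",
        "nickname": "someone",
        "password": "plain-text-password",  # already decrypted upstream (assumed)
        "login_channel": "password",
        "is_superuser": False,
    })
    if result["success"]:
        # deletion only succeeds once the account has been deactivated
        print(delete_user_data(result["user_info"]["id"])["message"])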
@@ -19,7 +19,7 @@ from pathlib import PurePath
from .user_service import UserService as UserService


def split_name_counter(filename: str) -> tuple[str, int | None]:
def _split_name_counter(filename: str) -> tuple[str, int | None]:
    """
    Splits a filename into main part and counter (if present in parentheses).

@@ -87,7 +87,7 @@ def duplicate_name(query_func, **kwargs) -> str:
    stem = path.stem
    suffix = path.suffix

    main_part, counter = split_name_counter(stem)
    main_part, counter = _split_name_counter(stem)
    counter = counter + 1 if counter else 1

    new_name = f"{main_part}({counter}){suffix}"
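
The renaming scheme is easy to trace from the fragment above: a stem without a trailing counter gets (1), and an existing (N) is bumped to (N+1). A self-contained re-implementation of that logic for illustration (not imported from the repo):

    from pathlib import PurePath


    def next_duplicate(name: str) -> str:
        path = PurePath(name)
        stem, suffix = path.stem, path.suffix
        main_part, counter = stem, None
        # split a trailing "(N)" counter off the stem, if present
        if stem.endswith(")") and "(" in stem:
            head, _, tail = stem.rpartition("(")
            if tail[:-1].isdigit():
                main_part, counter = head, int(tail[:-1])
        counter = counter + 1 if counter else 1
        return f"{main_part}({counter}){suffix}"


    print(next_duplicate("report.pdf"))     # report(1).pdf
    print(next_duplicate("report(1).pdf"))  # report(2).pdf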
@@ -35,6 +35,11 @@ class APITokenService(CommonService):
            cls.model.token == token
        )

    @classmethod
    @DB.connection_context()
    def delete_by_tenant_id(cls, tenant_id):
        return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()


class API4ConversationService(CommonService):
    model = API4Conversation
@@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
            cls.model.create_date <= to_date,
            cls.model.source == source
        ).group_by(cls.model.create_date.truncate("day")).dicts()

    @classmethod
    @DB.connection_context()
    def delete_by_dialog_ids(cls, dialog_ids):
        return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
@@ -63,7 +63,38 @@ class UserCanvasService(CommonService):

    @classmethod
    @DB.connection_context()
    def get_by_tenant_id(cls, pid):
    def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
        # will get all permitted agents, be cautious
        fields = [
            cls.model.id,
            cls.model.title,
            cls.model.permission,
            cls.model.canvas_type,
            cls.model.canvas_category
        ]
        # find team agents and owned agents
        agents = cls.model.select(*fields).where(
            (cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
                cls.model.user_id == user_id
            )
        )
        # sort by create_time, asc
        agents.order_by(cls.model.create_time.asc())
        # maybe cause slow query by deep paginate, optimize later
        offset, limit = 0, 50
        res = []
        while True:
            ag_batch = agents.offset(offset).limit(limit)
            _temp = list(ag_batch.dicts())
            if not _temp:
                break
            res.extend(_temp)
            offset += limit
        return res

    @classmethod
    @DB.connection_context()
    def get_by_canvas_id(cls, pid):
        try:

            fields = [
@@ -95,7 +126,7 @@ class UserCanvasService(CommonService):
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                          page_number, items_per_page,
                          orderby, desc, keywords, canvas_category=CanvasCategory.Agent,
                          orderby, desc, keywords, canvas_category=None
                          ):
        fields = [
            cls.model.id,
@@ -104,6 +135,7 @@ class UserCanvasService(CommonService):
            cls.model.dsl,
            cls.model.description,
            cls.model.permission,
            cls.model.user_id.alias("tenant_id"),
            User.nickname,
            User.avatar.alias('tenant_avatar'),
            cls.model.update_time,
@@ -111,31 +143,30 @@ class UserCanvasService(CommonService):
        ]
        if keywords:
            agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
                ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
                    TenantPermission.TEAM.value)) | (
                    cls.model.user_id == user_id)),
                (((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
                (fn.LOWER(cls.model.title).contains(keywords.lower()))
            )
        else:
            agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
                ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
                    TenantPermission.TEAM.value)) | (
                    cls.model.user_id == user_id))
                (((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
            )
        agents = agents.where(cls.model.canvas_category == canvas_category)
        if canvas_category:
            agents = agents.where(cls.model.canvas_category == canvas_category)
        if desc:
            agents = agents.order_by(cls.model.getter_by(orderby).desc())
        else:
            agents = agents.order_by(cls.model.getter_by(orderby).asc())

        count = agents.count()
        agents = agents.paginate(page_number, items_per_page)
        if page_number and items_per_page:
            agents = agents.paginate(page_number, items_per_page)
        return list(agents.dicts()), count

    @classmethod
    @DB.connection_context()
    def accessible(cls, canvas_id, tenant_id):
        from api.db.services.user_service import UserTenantService
        e, c = UserCanvasService.get_by_tenant_id(canvas_id)
        e, c = UserCanvasService.get_by_canvas_id(canvas_id)
        if not e:
            return False

|
|||||||
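A hedged sketch of calling the new batched accessor above from application code (assumes an initialized RAGFlow database context; the tenant and user ids are made up):

    from api.db.services.canvas_service import UserCanvasService

    # Every agent the user may see: team-shared agents of the joined tenants
    # plus the user's own, fetched 50 rows at a time until exhausted.
    agents = UserCanvasService.get_all_agents_by_tenant_ids(["tenant_a", "tenant_b"], "user_1")
    for agent in agents:
        print(agent["id"], agent["title"], agent["permission"])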
@@ -14,12 +14,24 @@
 # limitations under the License.
 #
 from datetime import datetime
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 import peewee
+from peewee import InterfaceError, OperationalError

 from api.db.db_models import DB
 from api.utils import current_timestamp, datetime_format, get_uuid


+def retry_db_operation(func):
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=1, max=5),
+        retry=retry_if_exception_type((InterfaceError, OperationalError)),
+        before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
+        reraise=True,
+    )
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+    return wrapper
+
+
 class CommonService:
     """Base service class that provides common database operations.
@@ -202,6 +214,7 @@ class CommonService:

     @classmethod
     @DB.connection_context()
+    @retry_db_operation
     def update_by_id(cls, pid, data):
         # Update a single record by ID
         # Args:
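A minimal standalone sketch of the retry behaviour introduced above, built from the same tenacity primitives (assumes only the tenacity and peewee packages are installed; flaky_op is a hypothetical stand-in for a database call):

    from peewee import OperationalError
    from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

    def retry_db_operation(func):
        # Same shape as the decorator added above: 3 attempts, exponential backoff.
        @retry(
            stop=stop_after_attempt(3),
            wait=wait_exponential(multiplier=1, min=1, max=5),
            retry=retry_if_exception_type(OperationalError),
            reraise=True,
        )
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    calls = {"n": 0}

    @retry_db_operation
    def flaky_op():
        # Fails twice with a transient error, then succeeds on the third attempt.
        calls["n"] += 1
        if calls["n"] < 3:
            raise OperationalError("connection lost")
        return "ok"

    print(flaky_op())  # retried twice under the hood, prints "ok"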
@@ -23,7 +23,7 @@ from api.db.services.dialog_service import DialogService, chat
 from api.utils import get_uuid
 import json

-from rag.prompts import chunks_format
+from rag.prompts.generator import chunks_format


 class ConversationService(CommonService):
@@ -48,6 +48,21 @@ class ConversationService(CommonService):

         return list(sessions.dicts())

+    @classmethod
+    @DB.connection_context()
+    def get_all_conversation_by_dialog_ids(cls, dialog_ids):
+        sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
+        sessions.order_by(cls.model.create_time.asc())
+        offset, limit = 0, 100
+        res = []
+        while True:
+            s_batch = sessions.offset(offset).limit(limit)
+            _temp = list(s_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res

 def structure_answer(conv, ans, message_id, session_id):
     reference = ans["reference"]
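The get_all_* helpers added across these services share one offset/limit batching idiom. A standalone illustration over a plain list (note, as a side observation, that peewee's order_by returns a new query rather than mutating in place, so the unassigned order_by calls in these helpers appear to leave the ordering unchanged):

    def fetch_all(fetch_page, limit=100):
        offset, res = 0, []
        while True:
            batch = fetch_page(offset, limit)
            if not batch:
                break
            res.extend(batch)
            offset += limit
        return res

    rows = list(range(250))
    assert fetch_all(lambda off, lim: rows[off:off + lim]) == rows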
@@ -39,8 +39,8 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
-from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
+from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
+    gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily

@@ -159,6 +159,22 @@ class DialogService(CommonService):

         return list(dialogs.dicts()), count

+    @classmethod
+    @DB.connection_context()
+    def get_all_dialogs_by_tenant_id(cls, tenant_id):
+        fields = [cls.model.id]
+        dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
+        dialogs.order_by(cls.model.create_time.asc())
+        offset, limit = 0, 100
+        res = []
+        while True:
+            d_batch = dialogs.offset(offset).limit(limit)
+            _temp = list(d_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res

 def chat_solo(dialog, messages, stream=True):
     if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
@@ -176,7 +192,7 @@ def chat_solo(dialog, messages, stream=True):
         delta_ans = ""
         for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
             answer = ans
-            delta_ans = ans[len(last_ans) :]
+            delta_ans = ans[len(last_ans):]
             if num_tokens_from_string(delta_ans) < 16:
                 continue
             last_ans = answer
@@ -261,13 +277,13 @@ def convert_conditions(metadata_condition):
         "not is": "≠"
     }
     return [
         {
             "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
             "key": cond["name"],
             "value": cond["value"]
         }
         for cond in metadata_condition.get("conditions", [])
     ]


 def meta_filter(metas: dict, filters: list[dict]):
@@ -284,19 +300,19 @@ def meta_filter(metas: dict, filters: list[dict]):
             value = str(value)

             for conds in [
                 (operator == "contains", str(value).lower() in str(input).lower()),
                 (operator == "not contains", str(value).lower() not in str(input).lower()),
                 (operator == "start with", str(input).lower().startswith(str(value).lower())),
                 (operator == "end with", str(input).lower().endswith(str(value).lower())),
                 (operator == "empty", not input),
                 (operator == "not empty", input),
                 (operator == "=", input == value),
                 (operator == "≠", input != value),
                 (operator == ">", input > value),
                 (operator == "<", input < value),
                 (operator == "≥", input >= value),
                 (operator == "≤", input <= value),
             ]:
                 try:
                     if all(conds):
                         ids.extend(docids)
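A standalone illustration of the (guard, comparison) tuple dispatch that meta_filter uses above; only a subset of operators is shown, and note that the whole list of tuples is built eagerly before iteration begins:

    def matches(operator, input, value):
        for conds in [
            (operator == "contains", str(value).lower() in str(input).lower()),
            (operator == "start with", str(input).lower().startswith(str(value).lower())),
            (operator == "=", str(input) == str(value)),
        ]:
            # all(conds) is true only for the row whose guard names the
            # active operator and whose comparison also holds.
            if all(conds):
                return True
        return False

    assert matches("contains", "Annual Report 2024", "report")
    assert not matches("=", "a", "b")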
@@ -354,7 +370,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         chat_mdl.bind_tools(toolcall_session, tools)
     bind_models_ts = timer()

-    retriever = settings.retrievaler
+    retriever = settings.retriever
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
    if "doc_ids" in messages[-1]:
@@ -450,13 +466,18 @@ def chat(dialog, messages, stream=True, **kwargs):
                rerank_mdl=rerank_mdl,
                rank_feature=label_question(" ".join(questions), kbs),
            )
+            if prompt_config.get("toc_enhance"):
+                cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
+                if cks:
+                    kbinfos["chunks"] = cks
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
-                ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
+                ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
+                                                     LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

@@ -467,7 +488,8 @@ def chat(dialog, messages, stream=True, **kwargs):
    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
-        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
+        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
+               "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@@ -565,7 +587,8 @@ def chat(dialog, messages, stream=True, **kwargs):

    if langfuse_tracer:
        langfuse_generation = langfuse_tracer.start_generation(
-            trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
+            trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
+            input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
        )

    if stream:
@@ -575,12 +598,12 @@ def chat(dialog, messages, stream=True, **kwargs):
            if thought:
                ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
-            delta_ans = ans[len(last_ans) :]
+            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        delta_ans = answer[len(last_ans) :]
+        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
@@ -639,7 +662,7 @@ Please write the SQL, only SQL, without any other explanations or text.

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
-        return settings.retrievaler.sql_retrieval(sql, format="json"), sql
+        return settings.retriever.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
@@ -676,7 +699,9 @@ Please write the SQL, only SQL, without any other explanations or text.

    # compose Markdown table
    columns = (
-        "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+        "|" + "|".join(
+            [re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
+            "|Source|" if docid_idx and docid_idx else "|")
    )

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@@ -731,7 +756,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
+    retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@@ -753,7 +778,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
        doc_ids = None

    kbinfos = retriever.retrieval(
-        question = question,
+        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
@@ -775,7 +800,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, sys_prompt
-        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
+        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
+                                                 embd_mdl, tkweight=0.7, vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
@@ -826,7 +852,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
    if not doc_ids:
        doc_ids = None

-    ranks = settings.retrievaler.retrieval(
+    ranks = settings.retriever.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
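A standalone illustration of the incremental-delta slicing (ans[len(last_ans):]) reformatted in the streaming branches above; each loop turn emits only the newly generated suffix of the accumulated answer:

    chunks = ["Hel", "Hello, wor", "Hello, world!"]
    last_ans = ""
    for ans in chunks:
        delta_ans = ans[len(last_ans):]  # the part not yet sent to the client
        last_ans = ans
        print(repr(delta_ans))  # 'Hel' then 'lo, wor' then 'ld!'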
@@ -24,12 +24,13 @@ from io import BytesIO

 import trio
 import xxhash
-from peewee import fn, Case
+from peewee import fn, Case, JOIN

 from api import settings
 from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
-from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole
-from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File
+from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole, CanvasCategory
+from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
+    User
 from api.db.db_utils import bulk_insert_into_db
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -51,6 +52,7 @@ class DocumentService(CommonService):
             cls.model.thumbnail,
             cls.model.kb_id,
             cls.model.parser_id,
+            cls.model.pipeline_id,
             cls.model.parser_config,
             cls.model.source_type,
             cls.model.type,
@@ -79,7 +81,10 @@ class DocumentService(CommonService):
     def get_list(cls, kb_id, page_number, items_per_page,
                  orderby, desc, keywords, id, name):
         fields = cls.get_cls_model_fields()
-        docs = cls.model.select(*fields).join(File2Document, on = (File2Document.document_id == cls.model.id)).join(File, on = (File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
+        docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
+            .join(File, on = (File.id == File2Document.file_id))\
+            .join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
+            .where(cls.model.kb_id == kb_id)
         if id:
             docs = docs.where(
                 cls.model.id == id)
@@ -117,12 +122,22 @@ class DocumentService(CommonService):
                  orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
-                (cls.model.kb_id == kb_id),
-                (fn.LOWER(cls.model.name).contains(keywords.lower()))
-            )
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
+                .where(
+                    (cls.model.kb_id == kb_id),
+                    (fn.LOWER(cls.model.name).contains(keywords.lower()))
+                )
         else:
-            docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
+                .join(File2Document, on=(File2Document.document_id == cls.model.id))\
+                .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+                .join(File, on=(File.id == File2Document.file_id))\
+                .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
+                .where(cls.model.kb_id == kb_id)

         if run_status:
             docs = docs.where(cls.model.run.in_(run_status))
@@ -228,6 +243,46 @@ class DocumentService(CommonService):

         return int(query.scalar()) or 0

+    @classmethod
+    @DB.connection_context()
+    def get_all_doc_ids_by_kb_ids(cls, kb_ids):
+        fields = [cls.model.id]
+        docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
+        docs.order_by(cls.model.create_time.asc())
+        # maybe cause slow query by deep paginate, optimize later
+        offset, limit = 0, 100
+        res = []
+        while True:
+            doc_batch = docs.offset(offset).limit(limit)
+            _temp = list(doc_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res
+
+    @classmethod
+    @DB.connection_context()
+    def get_all_docs_by_creator_id(cls, creator_id):
+        fields = [
+            cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
+        ]
+        docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
+            cls.model.created_by == creator_id
+        )
+        docs.order_by(cls.model.create_time.asc())
+        # maybe cause slow query by deep paginate, optimize later
+        offset, limit = 0, 100
+        res = []
+        while True:
+            doc_batch = docs.offset(offset).limit(limit)
+            _temp = list(doc_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res
+
     @classmethod
     @DB.connection_context()
     def insert(cls, doc):
@@ -330,8 +385,7 @@ class DocumentService(CommonService):
             process_duration=cls.model.process_duration + duration).where(
             cls.model.id == doc_id).execute()
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
+            logging.warning("Document not found which is supposed to be there")
         num = Knowledgebase.update(
             token_num=Knowledgebase.token_num +
             token_num,
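A standalone peewee sketch of the LEFT OUTER JOIN pattern the document-listing queries above adopt, with illustrative models rather than RAGFlow's; documents with no matching pipeline row survive the join with a NULL pipeline name:

    from peewee import JOIN, CharField, ForeignKeyField, Model, SqliteDatabase

    db = SqliteDatabase(":memory:")

    class Pipeline(Model):
        title = CharField()

        class Meta:
            database = db

    class Doc(Model):
        name = CharField()
        pipeline = ForeignKeyField(Pipeline, null=True, backref="docs")

        class Meta:
            database = db

    db.create_tables([Pipeline, Doc])
    flow = Pipeline.create(title="default_flow")
    Doc.create(name="a.pdf", pipeline=flow)
    Doc.create(name="b.pdf", pipeline=None)  # never ran through a pipeline

    query = (Doc
             .select(Doc.name, Pipeline.title.alias("pipeline_name"))
             .join(Pipeline, JOIN.LEFT_OUTER))
    for row in query.dicts():
        print(row)  # b.pdf is kept, with pipeline_name == None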
@@ -597,6 +651,22 @@ class DocumentService(CommonService):
     @DB.connection_context()
     def update_progress(cls):
         docs = cls.get_unfinished_docs()
+
+        cls._sync_progress(docs)
+
+
+    @classmethod
+    @DB.connection_context()
+    def update_progress_immediately(cls, docs:list[dict]):
+        if not docs:
+            return
+
+        cls._sync_progress(docs)
+
+
+    @classmethod
+    @DB.connection_context()
+    def _sync_progress(cls, docs:list[dict]):
         for d in docs:
             try:
                 tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
@@ -606,8 +676,6 @@ class DocumentService(CommonService):
                 prg = 0
                 finished = True
                 bad = 0
-                has_raptor = False
-                has_graphrag = False
                 e, doc = DocumentService.get_by_id(d["id"])
                 status = doc.run  # TaskStatus.RUNNING.value
                 priority = 0
@@ -619,24 +687,14 @@ class DocumentService(CommonService):
                     prg += t.progress if t.progress >= 0 else 0
                     if t.progress_msg.strip():
                         msg.append(t.progress_msg)
-                    if t.task_type == "raptor":
-                        has_raptor = True
-                    elif t.task_type == "graphrag":
-                        has_graphrag = True
                     priority = max(priority, t.priority)
                 prg /= len(tsks)
                 if finished and bad:
                     prg = -1
                     status = TaskStatus.FAIL.value
                 elif finished:
-                    if (d["parser_config"].get("raptor") or {}).get("use_raptor") and not has_raptor:
-                        queue_raptor_o_graphrag_tasks(d, "raptor", priority)
-                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    elif (d["parser_config"].get("graphrag") or {}).get("use_graphrag") and not has_graphrag:
-                        queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
-                        prg = 0.98 * len(tsks) / (len(tsks) + 1)
-                    else:
-                        status = TaskStatus.DONE.value
+                    prg = 1
+                    status = TaskStatus.DONE.value

                 msg = "\n".join(sorted(msg))
                 info = {
@@ -648,7 +706,7 @@ class DocumentService(CommonService):
                     info["progress"] = prg
                 if msg:
                     info["progress_msg"] = msg
-                    if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
+                    if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
                         info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
                 else:
                     info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
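The _sync_progress loop above rolls per-task progress up to the document. A standalone sketch that roughly mirrors the arithmetic (simplified: finished and bad are derived straight from the progress values):

    def roll_up(task_progresses):
        finished = all(p in (1, -1) for p in task_progresses)
        bad = sum(1 for p in task_progresses if p == -1)
        prg = sum(p for p in task_progresses if p >= 0) / len(task_progresses)
        if finished and bad:
            return -1   # at least one task failed
        if finished:
            return 1    # everything done
        return prg      # still running: averaged progress

    print(roll_up([1, 1, 1]))   # 1
    print(roll_up([1, -1]))     # -1
    print(roll_up([0.5, 0.0]))  # 0.25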
@@ -729,7 +787,11 @@ class DocumentService(CommonService):
             "cancelled": int(cancelled),
         }

-def queue_raptor_o_graphrag_tasks(doc, ty, priority):
+def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
+    """
+    You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
+    Optionally, specify a list of doc_ids to determine which documents participate in the task.
+    """
    chunking_config = DocumentService.get_chunking_config(doc["id"])
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
@@ -739,11 +801,12 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
        nonlocal doc
        return {
            "id": get_uuid(),
-            "doc_id": doc["id"],
+            "doc_id": fake_doc_id if fake_doc_id else doc["id"],
            "from_page": 100000000,
            "to_page": 100000000,
            "task_type": ty,
-            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
+            "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
+            "begin_at": datetime.now(),
        }

    task = new_task()
@@ -752,7 +815,12 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
    hasher.update(ty.encode("utf-8"))
    task["digest"] = hasher.hexdigest()
    bulk_insert_into_db(Task, [task], True)
+
+    if ty in ["graphrag", "raptor", "mindmap"]:
+        task["doc_ids"] = doc_ids
+    DocumentService.begin2parse(doc["id"])
    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
+    return task["id"]


def get_queue_length(priority):
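A hedged sketch of queueing a knowledgebase-level task with the new parameters (assumes a running RAGFlow deployment with Redis available; the ids are made up):

    from api.db.services.document_service import queue_raptor_o_graphrag_tasks
    from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID

    doc = {"id": "doc_123"}  # anchor document supplying the chunking config
    task_id = queue_raptor_o_graphrag_tasks(
        doc, "raptor", priority=0,
        fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID,  # bypass the per-document restriction
        doc_ids=["doc_123", "doc_456"],        # documents taking part in the task
    )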
@@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
         objs = cls.model.select().where(cls.model.document_id == document_id)
         return objs

+    @classmethod
+    @DB.connection_context()
+    def get_by_document_ids(cls, document_ids):
+        objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
+        return list(objs.dicts())
+
     @classmethod
     @DB.connection_context()
     def insert(cls, obj):
@@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
     def delete_by_file_id(cls, file_id):
         return cls.model.delete().where(cls.model.file_id == file_id).execute()

+    @classmethod
+    @DB.connection_context()
+    def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
+        if not document_ids:
+            return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
+        elif not file_ids:
+            return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
+        return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
+
     @classmethod
     @DB.connection_context()
     def delete_by_document_id(cls, doc_id):
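A standalone illustration of the three-branch delete above, over in-memory (document_id, file_id) mapping rows; when both id lists are given, a row matching either side is removed:

    rows = [("d1", "f1"), ("d2", "f2"), ("d3", "f3")]

    def delete_by_ids(rows, document_ids, file_ids):
        if not document_ids:
            return [r for r in rows if r[1] not in file_ids]
        if not file_ids:
            return [r for r in rows if r[0] not in document_ids]
        return [r for r in rows if r[0] not in document_ids and r[1] not in file_ids]

    print(delete_by_ids(rows, ["d1"], ["f3"]))  # [('d2', 'f2')]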
@@ -161,6 +161,23 @@ class FileService(CommonService):
                 result_ids.append(folder_id)
         return result_ids

+    @classmethod
+    @DB.connection_context()
+    def get_all_file_ids_by_tenant_id(cls, tenant_id):
+        fields = [cls.model.id]
+        files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
+        files.order_by(cls.model.create_time.asc())
+        offset, limit = 0, 100
+        res = []
+        while True:
+            file_batch = files.offset(offset).limit(limit)
+            _temp = list(file_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res
+
     @classmethod
     @DB.connection_context()
     def create_folder(cls, file, parent_id, name, count):
@@ -440,6 +457,7 @@ class FileService(CommonService):
             "id": doc_id,
             "kb_id": kb.id,
             "parser_id": self.get_parser(filetype, filename, kb.parser_id),
+            "pipeline_id": kb.pipeline_id,
             "parser_config": kb.parser_config,
             "created_by": user_id,
             "type": filetype,
@@ -495,7 +513,7 @@ class FileService(CommonService):
             return ParserType.AUDIO.value
         if re.search(r"\.(ppt|pptx|pages)$", filename):
             return ParserType.PRESENTATION.value
-        if re.search(r"\.(eml)$", filename):
+        if re.search(r"\.(msg|eml)$", filename):
             return ParserType.EMAIL.value
         return default
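A quick standalone check of the widened e-mail suffix routing above; Outlook .msg files now reach the EMAIL parser alongside .eml:

    import re

    for filename in ["minutes.eml", "minutes.msg", "minutes.txt"]:
        matched = bool(re.search(r"\.(msg|eml)$", filename))
        print(filename, "->", "EMAIL parser" if matched else "default parser")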
@@ -15,10 +15,10 @@
 #
 from datetime import datetime

-from peewee import fn
+from peewee import fn, JOIN

 from api.db import StatusEnum, TenantPermission
-from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
+from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
 from api.db.services.common_service import CommonService
 from api.utils import current_timestamp, datetime_format

@@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):

         return list(kbs.dicts()), count

+    @classmethod
+    @DB.connection_context()
+    def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
+        # will get all permitted kb, be cautious.
+        fields = [
+            cls.model.name,
+            cls.model.language,
+            cls.model.permission,
+            cls.model.doc_num,
+            cls.model.token_num,
+            cls.model.chunk_num,
+            cls.model.status,
+            cls.model.create_date,
+            cls.model.update_date
+        ]
+        # find team kb and owned kb
+        kbs = cls.model.select(*fields).where(
+            (cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
+                cls.model.tenant_id == user_id
+            )
+        )
+        # sort by create_time asc
+        kbs.order_by(cls.model.create_time.asc())
+        # maybe cause slow query by deep paginate, optimize later.
+        offset, limit = 0, 50
+        res = []
+        while True:
+            kb_batch = kbs.offset(offset).limit(limit)
+            _temp = list(kb_batch.dicts())
+            if not _temp:
+                break
+            res.extend(_temp)
+            offset += limit
+        return res
+
     @classmethod
     @DB.connection_context()
     def get_kb_ids(cls, tenant_id):
@@ -225,20 +260,29 @@ class KnowledgebaseService(CommonService):
             cls.model.token_num,
             cls.model.chunk_num,
             cls.model.parser_id,
+            cls.model.pipeline_id,
+            UserCanvas.title.alias("pipeline_name"),
+            UserCanvas.avatar.alias("pipeline_avatar"),
             cls.model.parser_config,
             cls.model.pagerank,
+            cls.model.graphrag_task_id,
+            cls.model.graphrag_task_finish_at,
+            cls.model.raptor_task_id,
+            cls.model.raptor_task_finish_at,
+            cls.model.mindmap_task_id,
+            cls.model.mindmap_task_finish_at,
             cls.model.create_time,
             cls.model.update_time
         ]
-        kbs = cls.model.select(*fields).join(Tenant, on=(
-            (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
+        kbs = cls.model.select(*fields)\
+            .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
+            .where(
            (cls.model.id == kb_id),
            (cls.model.status == StatusEnum.VALID.value)
-        )
+        ).dicts()
        if not kbs:
            return
-        d = kbs[0].to_dict()
-        return d
+        return kbs[0]

     @classmethod
     @DB.connection_context()
@@ -335,6 +379,7 @@ class KnowledgebaseService(CommonService):
        # name: Optional name filter
        # Returns:
        #     List of knowledge bases
+       #     Total count of knowledge bases
        kbs = cls.model.select()
        if id:
            kbs = kbs.where(cls.model.id == id)
@@ -346,14 +391,16 @@ class KnowledgebaseService(CommonService):
                cls.model.tenant_id == user_id))
            & (cls.model.status == StatusEnum.VALID.value)
        )

        if desc:
            kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
        else:
            kbs = kbs.order_by(cls.model.getter_by(orderby).asc())

+        total = kbs.count()
        kbs = kbs.paginate(page_number, items_per_page)

-        return list(kbs.dicts())
+        return list(kbs.dicts()), total

     @classmethod
     @DB.connection_context()
@@ -436,3 +483,17 @@ class KnowledgebaseService(CommonService):
         else:
             raise e
+
+    @classmethod
+    @DB.connection_context()
+    def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
+        kb_row = cls.model.get_by_id(kb_id)
+        if not kb_row:
+            raise RuntimeError(f"kb_id {kb_id} does not exist")
+        update_dict = {
+            'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
+            'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
+            'token_num': kb_row.token_num - doc_num_info['token_num'],
+            'update_time': current_timestamp(),
+            'update_date': datetime_format(datetime.now())
+        }
+        return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
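A hedged sketch of reconciling knowledgebase counters with the new helper after a bulk document delete (assumes an initialized RAGFlow context; the id and the numbers are made up):

    from api.db.services.knowledgebase_service import KnowledgebaseService

    # Subtract the totals of the documents just removed from the kb row.
    KnowledgebaseService.decrease_document_num_in_delete(
        "kb_42",
        {"doc_num": 3, "chunk_num": 120, "token_num": 45000},
    )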
@@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
         except peewee.DoesNotExist:
             return None

+    @classmethod
+    @DB.connection_context()
+    def delete_ty_tenant_id(cls, tenant_id):
+        return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
+
     @classmethod
     def update_by_tenant(cls, tenant_id, langfuse_keys):
         langfuse_keys["update_time"] = current_timestamp()
@@ -33,7 +33,8 @@ class MCPServerService(CommonService):

     @classmethod
     @DB.connection_context()
-    def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords):
+    def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc,
+                    keywords):
         """Retrieve all MCP servers associated with a tenant.

         This method fetches all MCP servers for a given tenant, ordered by creation time.
@@ -84,3 +85,8 @@ class MCPServerService(CommonService):
             return bool(mcp_server), mcp_server
         except Exception:
             return False, None
+
+    @classmethod
+    @DB.connection_context()
+    def delete_by_tenant_id(cls, tenant_id: str):
+        return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
263  api/db/services/pipeline_operation_log_service.py  Normal file
@@ -0,0 +1,263 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+import json
+import logging
+import os
+from datetime import datetime, timedelta
+
+from peewee import fn
+
+from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType
+from api.db.db_models import DB, Document, PipelineOperationLog
+from api.db.services.canvas_service import UserCanvasService
+from api.db.services.common_service import CommonService
+from api.db.services.document_service import DocumentService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
+from api.utils import current_timestamp, datetime_format, get_uuid
+
+
+class PipelineOperationLogService(CommonService):
+    model = PipelineOperationLog
+
+    @classmethod
+    def get_file_logs_fields(cls):
+        return [
+            cls.model.id,
+            cls.model.document_id,
+            cls.model.tenant_id,
+            cls.model.kb_id,
+            cls.model.pipeline_id,
+            cls.model.pipeline_title,
+            cls.model.parser_id,
+            cls.model.document_name,
+            cls.model.document_suffix,
+            cls.model.document_type,
+            cls.model.source_from,
+            cls.model.progress,
+            cls.model.progress_msg,
+            cls.model.process_begin_at,
+            cls.model.process_duration,
+            cls.model.dsl,
+            cls.model.task_type,
+            cls.model.operation_status,
+            cls.model.avatar,
+            cls.model.status,
+            cls.model.create_time,
+            cls.model.create_date,
+            cls.model.update_time,
+            cls.model.update_date,
+        ]
+
+    @classmethod
+    def get_dataset_logs_fields(cls):
+        return [
+            cls.model.id,
+            cls.model.tenant_id,
+            cls.model.kb_id,
+            cls.model.progress,
+            cls.model.progress_msg,
+            cls.model.process_begin_at,
+            cls.model.process_duration,
+            cls.model.task_type,
+            cls.model.operation_status,
+            cls.model.avatar,
+            cls.model.status,
+            cls.model.create_time,
+            cls.model.create_date,
+            cls.model.update_time,
+            cls.model.update_date,
+        ]
+
+    @classmethod
+    def save(cls, **kwargs):
+        """
+        wrap this function in a transaction
+        """
+        sample_obj = cls.model(**kwargs).save(force_insert=True)
+        return sample_obj
+
+    @classmethod
+    @DB.connection_context()
+    def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: str = "{}"):
+        referred_document_id = document_id
+
+        if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
+            referred_document_id = fake_document_ids[0]
+        ok, document = DocumentService.get_by_id(referred_document_id)
+        if not ok:
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
+        DocumentService.update_progress_immediately([document.to_dict()])
+        ok, document = DocumentService.get_by_id(referred_document_id)
+        if not ok:
+            logging.warning(f"Document for referred_document_id {referred_document_id} not found")
+            return
+        if document.progress not in [1, -1]:
+            return
+        operation_status = document.run
+
+        if pipeline_id:
+            ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
+            if not ok:
+                raise RuntimeError(f"Pipeline {pipeline_id} not found")
+            tenant_id = user_pipeline.user_id
+            title = user_pipeline.title
+            avatar = user_pipeline.avatar
+        else:
+            ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
+            if not ok:
+                raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
+
+            tenant_id = kb_info.tenant_id
+            title = document.parser_id
+            avatar = document.thumbnail
+
+        if task_type not in VALID_PIPELINE_TASK_TYPES:
+            raise ValueError(f"Invalid task type: {task_type}")
+
+        if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
+            finish_at = document.process_begin_at + timedelta(seconds=document.process_duration)
+            if task_type == PipelineTaskType.GRAPH_RAG:
+                KnowledgebaseService.update_by_id(
+                    document.kb_id,
+                    {"graphrag_task_finish_at": finish_at},
+                )
+            elif task_type == PipelineTaskType.RAPTOR:
+                KnowledgebaseService.update_by_id(
+                    document.kb_id,
+                    {"raptor_task_finish_at": finish_at},
+                )
+            elif task_type == PipelineTaskType.MINDMAP:
+                KnowledgebaseService.update_by_id(
+                    document.kb_id,
+                    {"mindmap_task_finish_at": finish_at},
+                )
+
+        log = dict(
+            id=get_uuid(),
+            document_id=document_id,  # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
+            tenant_id=tenant_id,
+            kb_id=document.kb_id,
+            pipeline_id=pipeline_id,
+            pipeline_title=title,
+            parser_id=document.parser_id,
+            document_name=document.name,
+            document_suffix=document.suffix,
+            document_type=document.type,
+            source_from="",  # TODO: add in the future
+            progress=document.progress,
+            progress_msg=document.progress_msg,
+            process_begin_at=document.process_begin_at,
+            process_duration=document.process_duration,
+            dsl=json.loads(dsl),
+            task_type=task_type,
+            operation_status=operation_status,
+            avatar=avatar,
+        )
+        log["create_time"] = current_timestamp()
+        log["create_date"] = datetime_format(datetime.now())
+        log["update_time"] = current_timestamp()
+        log["update_date"] = datetime_format(datetime.now())
+
+        with DB.atomic():
+            obj = cls.save(**log)
+
+            limit = int(os.getenv("PIPELINE_OPERATION_LOG_LIMIT", 1000))
+            total = cls.model.select().where(cls.model.kb_id == document.kb_id).count()
+
+            if total > limit:
+                keep_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == document.kb_id).order_by(cls.model.create_time.desc()).limit(limit)]
+
+                deleted = cls.model.delete().where(cls.model.kb_id == document.kb_id, cls.model.id.not_in(keep_ids)).execute()
+                logging.info(f"[PipelineOperationLogService] Cleaned {deleted} old logs, kept latest {limit} for {document.kb_id}")
+
+        return obj
+
+    @classmethod
+    @DB.connection_context()
+    def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
+        return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
+
+    @classmethod
+    @DB.connection_context()
+    def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
+        fields = cls.get_file_logs_fields()
+        if keywords:
+            logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
+        else:
+            logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
+
+        logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
+
+        if operation_status:
+            logs = logs.where(cls.model.operation_status.in_(operation_status))
+        if types:
+            logs = logs.where(cls.model.document_type.in_(types))
+        if suffix:
+            logs = logs.where(cls.model.document_suffix.in_(suffix))
+        if create_date_from:
+            logs = logs.where(cls.model.create_date >= create_date_from)
+        if create_date_to:
+            logs = logs.where(cls.model.create_date <= create_date_to)
+
+        count = logs.count()
+        if desc:
+            logs = logs.order_by(cls.model.getter_by(orderby).desc())
+        else:
+            logs = logs.order_by(cls.model.getter_by(orderby).asc())
+
+        if page_number and items_per_page:
+            logs = logs.paginate(page_number, items_per_page)
+
+        return list(logs.dicts()), count
+
+    @classmethod
+    @DB.connection_context()
+    def get_documents_info(cls, id):
+        fields = [Document.id, Document.name, Document.progress, Document.kb_id]
+        return (
+            cls.model.select(*fields)
+            .join(Document, on=(cls.model.document_id == Document.id))
+            .where(
+                cls.model.id == id
+            )
+            .dicts()
+        )
+
+    @classmethod
+    @DB.connection_context()
+    def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
+        fields = cls.get_dataset_logs_fields()
+        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
+
+        if operation_status:
+            logs = logs.where(cls.model.operation_status.in_(operation_status))
+        if create_date_from:
+            logs = logs.where(cls.model.create_date >= create_date_from)
+        if create_date_to:
+            logs = logs.where(cls.model.create_date <= create_date_to)
+
+        count = logs.count()
+        if desc:
+            logs = logs.order_by(cls.model.getter_by(orderby).desc())
+        else:
+            logs = logs.order_by(cls.model.getter_by(orderby).asc())
+
+        if page_number and items_per_page:
+            logs = logs.paginate(page_number, items_per_page)
+
+        return list(logs.dicts()), count
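A standalone illustration of the keep-latest-N retention the new log service applies after each insert: rows are ranked by create_time descending and everything beyond the cap is dropped:

    logs = [{"id": i, "create_time": i} for i in range(1, 8)]  # 7 rows, cap of 5
    limit = 5
    keep_ids = {r["id"] for r in sorted(logs, key=lambda r: r["create_time"], reverse=True)[:limit]}
    logs = [r for r in logs if r["id"] in keep_ids]
    print(sorted(r["id"] for r in logs))  # [3, 4, 5, 6, 7]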
@@ -94,7 +94,8 @@ class SearchService(CommonService):
         query = (
             cls.model.select(*fields)
             .join(User, on=(cls.model.tenant_id == User.id))
-            .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
+            .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (
+                cls.model.status == StatusEnum.VALID.value))
         )

         if keywords:
@@ -110,3 +111,8 @@ class SearchService(CommonService):
         query = query.paginate(page_number, items_per_page)

         return list(query.dicts()), count
+
+    @classmethod
+    @DB.connection_context()
+    def delete_by_tenant_id(cls, tenant_id):
+        return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
@@ -35,6 +35,8 @@ from rag.utils.redis_conn import REDIS_CONN
 from api import settings
 from rag.nlp import search
 
+CANVAS_DEBUG_DOC_ID = "dataflow_x"
+GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
 
 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
@@ -70,7 +72,7 @@ class TaskService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def get_task(cls, task_id):
+    def get_task(cls, task_id, doc_ids=[]):
         """Retrieve detailed task information by task ID.
 
         This method fetches comprehensive task details including associated document,
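One review note on the new signature: doc_ids=[] is a mutable default, which Python evaluates once and shares across calls. That is harmless as long as the list is never mutated inside the method, but the defensive idiom is worth remembering (runnable sketch, not this repo's code):

def risky(xs=[]):
    # one shared list for every call
    xs.append(1)
    return len(xs)

def safe(xs=None):
    xs = [] if xs is None else xs  # fresh default per call
    xs.append(1)
    return len(xs)

assert risky() == 1 and risky() == 2  # state leaks between calls
assert safe() == 1 and safe() == 1    # no leak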
@@ -84,6 +86,10 @@ class TaskService(CommonService):
         dict: Task details dictionary containing all task information and related metadata.
         Returns None if task is not found or has exceeded retry limit.
         """
+        doc_id = cls.model.doc_id
+        if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
+            doc_id = doc_ids[0]
+
         fields = [
             cls.model.id,
             cls.model.doc_id,
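The added lines swap the CANVAS_DEBUG_DOC_ID placeholder for a caller-supplied real document id before the join below. The intended value-level behavior, sketched as a standalone helper (hypothetical name, not this repo's code):

CANVAS_DEBUG_DOC_ID = "dataflow_x"

def resolve_doc_id(task_doc_id, doc_ids):
    # replace the debug placeholder with the first real doc id, if any
    if task_doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
        return doc_ids[0]
    return task_doc_id

assert resolve_doc_id("dataflow_x", ["real_doc"]) == "real_doc"
assert resolve_doc_id("doc_7", ["real_doc"]) == "doc_7"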
@@ -109,7 +115,7 @@ class TaskService(CommonService):
         ]
         docs = (
             cls.model.select(*fields)
-            .join(Document, on=(cls.model.doc_id == Document.id))
+            .join(Document, on=(doc_id == Document.id))
             .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
             .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
             .where(cls.model.id == task_id)
@@ -159,7 +165,7 @@ class TaskService(CommonService):
         ]
         tasks = (
             cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
             .where(cls.model.doc_id == doc_id)
         )
         tasks = list(tasks.dicts())
         if not tasks:
@@ -199,18 +205,18 @@ class TaskService(CommonService):
             cls.model.select(
                 *[Document.id, Document.kb_id, Document.location, File.parent_id]
             )
             .join(Document, on=(cls.model.doc_id == Document.id))
             .join(
                 File2Document,
                 on=(File2Document.document_id == Document.id),
                 join_type=JOIN.LEFT_OUTER,
             )
             .join(
                 File,
                 on=(File2Document.file_id == File.id),
                 join_type=JOIN.LEFT_OUTER,
             )
             .where(
                 Document.status == StatusEnum.VALID.value,
                 Document.run == TaskStatus.RUNNING.value,
                 ~(Document.type == FileType.VIRTUAL.value),
@@ -288,25 +294,33 @@ class TaskService(CommonService):
             cls.model.update(progress=prog).where(
                 (cls.model.id == id) &
                 (
                     (cls.model.progress != -1) &
                     ((prog == -1) | (prog > cls.model.progress))
                 )
             ).execute()
-            return
-
-        with DB.lock("update_progress", -1):
-            if info["progress_msg"]:
-                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
-                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
-            if "progress" in info:
-                prog = info["progress"]
-                cls.model.update(progress=prog).where(
-                    (cls.model.id == id) &
-                    (
-                        (cls.model.progress != -1) &
-                        ((prog == -1) | (prog > cls.model.progress))
-                    )
-                ).execute()
+        else:
+            with DB.lock("update_progress", -1):
+                if info["progress_msg"]:
+                    progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
+                    cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
+                if "progress" in info:
+                    prog = info["progress"]
+                    cls.model.update(progress=prog).where(
+                        (cls.model.id == id) &
+                        (
+                            (cls.model.progress != -1) &
+                            ((prog == -1) | (prog > cls.model.progress))
+                        )
+                    ).execute()
+
+        process_duration = (datetime.now() - task.begin_at).total_seconds()
+        cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
+
+    @classmethod
+    @DB.connection_context()
+    def delete_by_doc_ids(cls, doc_ids):
+        """Delete task associated with a document."""
+        return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
 
 
 def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
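The WHERE guard this hunk repeats makes progress writes monotonic: a task already marked failed (progress == -1) is frozen, and otherwise only a failure marker or a strictly larger value wins, so stale out-of-order worker updates are dropped. The same invariant at the value level (plain-Python sketch, not the service's code):

def should_update(current, incoming):
    if current == -1:  # task already failed: frozen
        return False
    return incoming == -1 or incoming > current

assert should_update(0.3, 0.5)      # forward progress
assert should_update(0.3, -1)       # a failure can always be recorded
assert not should_update(0.5, 0.3)  # stale update dropped
assert not should_update(-1, 0.9)   # failed tasks stay failed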
@@ -329,8 +343,16 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
     - Task digests are calculated for optimization and reuse
     - Previous task chunks may be reused if available
     """
 
     def new_task():
-        return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
+        return {
+            "id": get_uuid(),
+            "doc_id": doc["id"],
+            "progress": 0.0,
+            "from_page": 0,
+            "to_page": 100000000,
+            "begin_at": datetime.now(),
+        }
 
     parse_task_array = []
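Stamping begin_at into every new task is what lets update_progress later derive process_duration as wall-clock seconds. The bookkeeping, in isolation:

from datetime import datetime

task = {"begin_at": datetime.now()}  # set when the task is created
# ... parsing work happens here ...
process_duration = (datetime.now() - task["begin_at"]).total_seconds()
print(f"{process_duration:.3f}s")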
@@ -343,7 +365,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
         page_size = doc["parser_config"].get("task_page_size") or 12
         if doc["parser_id"] == "paper":
             page_size = doc["parser_config"].get("task_page_size") or 22
-        if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
+        if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc", True):
             page_size = 10 ** 9
         page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
         for s, e in page_ranges:
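page_size decides how a document's page range is cut into parallel parse tasks, and forcing it to 10 ** 9 (as the new toc condition does) collapses the whole document into a single task. The slicing arithmetic, illustrated with a hypothetical helper; the repo's actual loop follows this hunk:

def split_pages(start, end, page_size):
    # cut [start, end) into consecutive chunks of at most page_size pages
    return [(p, min(end, p + page_size)) for p in range(start, end, page_size)]

assert split_pages(0, 30, 12) == [(0, 12), (12, 24), (24, 30)]
assert split_pages(0, 30, 10 ** 9) == [(0, 30)]  # one task for the whole doc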
@@ -472,36 +494,29 @@ def has_canceled(task_id):
     return False
 
 
-def queue_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, priority: int, callback=None) -> tuple[bool, str]:
-    """
-    Returns a tuple (success: bool, error_message: str).
-    """
-    _ = callback
-
+def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
     task = dict(
-        id=get_uuid() if not task_id else task_id,
+        id=task_id,
         doc_id=doc_id,
         from_page=0,
         to_page=100000000,
-        task_type="dataflow",
+        task_type="dataflow" if not rerun else "dataflow_rerun",
         priority=priority,
+        begin_at=datetime.now(),
     )
-    TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
+    if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
+        TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
+        DocumentService.begin2parse(doc_id)
     bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
 
-    kb_id = DocumentService.get_knowledgebase_id(doc_id)
-    if not kb_id:
-        return False, f"Can't find KB of this document: {doc_id}"
-
-    task["kb_id"] = kb_id
+    task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
     task["tenant_id"] = tenant_id
-    task["task_type"] = "dataflow"
-    task["dsl"] = dsl
-    task["dataflow_id"] = get_uuid() if not flow_id else flow_id
+    task["dataflow_id"] = flow_id
+    task["file"] = file
 
     if not REDIS_CONN.queue_product(
         get_svr_queue_name(priority), message=task
     ):
         return False, "Can't access Redis. Please check the Redis' status."
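The reworked queue_dataflow drops the dsl and callback parameters, defaults doc_id to the CANVAS_DEBUG_DOC_ID sentinel for debug runs, and only purges old tasks and resets parse state for real documents. A call sketch with hypothetical argument values (not taken from this diff):

ok, err = queue_dataflow(
    tenant_id="tenant_1",
    flow_id="flow_42",
    task_id=get_uuid(),      # the caller now supplies the task id
    doc_id="doc_7",          # a real doc: old tasks deleted, begin2parse called
    file={"name": "a.pdf"},  # carried to the worker inside the task message
    priority=0,
    rerun=False,             # True would enqueue task_type="dataflow_rerun"
)
if not ok:
    raise RuntimeError(err)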
@@ -57,8 +57,10 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
-        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
+        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
+                  cls.model.used_tokens]
+        objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
+            cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
 
         return list(objs)
@@ -122,7 +124,8 @@ class TenantLLMService(CommonService):
             model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
         if not model_config:
             if mdlnm == "flag-embedding":
-                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
+                model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name,
+                                "api_base": ""}
             else:
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")
@@ -137,27 +140,33 @@ class TenantLLMService(CommonService):
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
-            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
+                                                               base_url=model_config["api_base"])
 
         if llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
                 return
-            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+            return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
+                                                            base_url=model_config["api_base"])
 
         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
                 return
-            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
+            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
+                                                        base_url=model_config["api_base"], **kwargs)
 
         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
                 return
-            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
+            return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
+                                                          base_url=model_config["api_base"], **kwargs)
 
         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
                 return
-            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
+            return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
+                                                             model_name=model_config["llm_name"], lang=lang,
+                                                             base_url=model_config["api_base"])
         if llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
                 return
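Every branch of get_model makes the same dict-based factory move: look the llm_factory name up in a per-family registry, return None if it is unknown, otherwise instantiate with the stored credentials. The pattern in miniature (runnable, with a hypothetical registry in place of EmbeddingModel and friends):

class _Echo:
    # hypothetical stand-in for a model wrapper class
    def __init__(self, api_key, model_name, base_url=""):
        self.model_name = model_name

REGISTRY = {"OpenAI": _Echo, "Tongyi-Qianwen": _Echo}

def get_model(config):
    cls = REGISTRY.get(config["llm_factory"])
    if cls is None:  # unknown factory: the same silent None as the service
        return None
    return cls(config["api_key"], config["llm_name"], base_url=config["api_base"])

mdl = get_model({"llm_factory": "OpenAI", "api_key": "", "llm_name": "gpt-x", "api_base": ""})
assert mdl is not None and mdl.model_name == "gpt-x"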
@@ -194,11 +203,14 @@ class TenantLLMService(CommonService):
         try:
             num = (
                 cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
-                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
+                .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name,
+                       cls.model.llm_factory == llm_factory if llm_factory else True)
                 .execute()
             )
         except Exception:
-            logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
+            logging.exception(
+                "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
+                tenant_id, llm_name)
             return 0
 
         return num
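increase_usage pushes the addition into the UPDATE itself (used_tokens = used_tokens + n), so the database applies the increment atomically and concurrent workers cannot lose each other's updates to a read-modify-write race. A runnable Peewee sketch of the same idiom, with a hypothetical TenantLLM stand-in model:

from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(":memory:")

class TenantLLM(Model):
    # hypothetical stand-in
    tenant_id = CharField()
    llm_name = CharField()
    used_tokens = IntegerField(default=0)

    class Meta:
        database = db

db.create_tables([TenantLLM])
TenantLLM.create(tenant_id="t1", llm_name="gpt-x")

# server-side add: the database computes used_tokens + 42 in place
(TenantLLM
 .update(used_tokens=TenantLLM.used_tokens + 42)
 .where(TenantLLM.tenant_id == "t1", TenantLLM.llm_name == "gpt-x")
 .execute())

assert TenantLLM.get(TenantLLM.tenant_id == "t1").used_tokens == 42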
@@ -206,9 +218,16 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_openai_models(cls):
-        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
+        objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"),
+                                        ~(cls.model.llm_name == "text-embedding-3-small"),
+                                        ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
         return list(objs)
 
+    @classmethod
+    @DB.connection_context()
+    def delete_by_tenant_id(cls, tenant_id):
+        return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
+
     @staticmethod
     def llm_id2llm_type(llm_id: str) -> str | None:
         from api.db.services.llm_service import LLMService
@@ -245,8 +264,9 @@ class LLM4Tenant:
         langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
         self.langfuse = None
         if langfuse_keys:
-            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
+            langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key,
+                                host=langfuse_keys.host)
             if langfuse.auth_check():
                 self.langfuse = langfuse
                 trace_id = self.langfuse.create_trace_id()
                 self.trace_context = {"trace_id": trace_id}
Some files were not shown because too many files have changed in this diff.